From bdb5d8e19258f5d447f9fd235e9d77e0e9527dc5 Mon Sep 17 00:00:00 2001 From: Patrick Stevens Date: Sat, 17 Jun 2023 22:14:37 +0100 Subject: [PATCH] One-hot encoding (#30) --- little_learner_app/src/iris.rs | 79 ++++++++++++++++++++++++---------- little_learner_app/src/main.rs | 13 +++++- 2 files changed, 67 insertions(+), 25 deletions(-) diff --git a/little_learner_app/src/iris.rs b/little_learner_app/src/iris.rs index 449a9cb..ea00024 100644 --- a/little_learner_app/src/iris.rs +++ b/little_learner_app/src/iris.rs @@ -1,14 +1,18 @@ use csv::ReaderBuilder; +use little_learner::auto_diff::RankedDifferentiable; +use little_learner::scalar::Scalar; +use little_learner::traits::{One, Zero}; +use std::fmt::Debug; use std::io::Cursor; use std::str::FromStr; const IRIS_DATA: &str = include_str!("iris.csv"); -#[derive(Eq, PartialEq, Debug)] +#[derive(Eq, PartialEq, Debug, Clone, Copy)] pub enum IrisType { - Setosa, - Versicolor, - Virginica, + Setosa = 0, + Versicolor = 1, + Virginica = 2, } impl FromStr for IrisType { @@ -25,25 +29,29 @@ impl FromStr for IrisType { } #[derive(PartialEq, Debug)] -pub struct Iris { +pub struct Iris { pub class: IrisType, - pub petal_length: f32, - pub petal_width: f32, - pub sepal_length: f32, - pub sepal_width: f32, + pub petal_length: A, + pub petal_width: A, + pub sepal_length: A, + pub sepal_width: A, } -pub fn import() -> Vec { +pub fn import() -> Vec> +where + A: FromStr, + B: Debug, +{ let mut reader = ReaderBuilder::new() .has_headers(false) .from_reader(Cursor::new(IRIS_DATA)); let mut output = Vec::new(); for record in reader.records() { let record = record.unwrap(); - let petal_length = f32::from_str(&record[0]).unwrap(); - let petal_width = f32::from_str(&record[1]).unwrap(); - let sepal_length = f32::from_str(&record[2]).unwrap(); - let sepal_width = f32::from_str(&record[3]).unwrap(); + let petal_length = A::from_str(&record[0]).unwrap(); + let petal_width = A::from_str(&record[1]).unwrap(); + let sepal_length = A::from_str(&record[2]).unwrap(); + let sepal_width = A::from_str(&record[3]).unwrap(); let class = IrisType::from_str(&record[4]).unwrap(); output.push(Iris { class, @@ -57,21 +65,46 @@ pub fn import() -> Vec { output } -pub(crate) const EXPECTED_FIRST: Iris = Iris { - class: IrisType::Setosa, - petal_length: 5.1, - petal_width: 3.5, - sepal_length: 1.4, - sepal_width: 0.2, -}; +impl Iris { + pub fn one_hot(&self) -> (RankedDifferentiable, RankedDifferentiable) + where + A: Clone + Zero + One, + { + let vec = vec![ + RankedDifferentiable::of_scalar(Scalar::make(self.petal_length.clone())), + RankedDifferentiable::of_scalar(Scalar::make(self.petal_width.clone())), + RankedDifferentiable::of_scalar(Scalar::make(self.sepal_length.clone())), + RankedDifferentiable::of_scalar(Scalar::make(self.sepal_width.clone())), + ]; + + let mut one_hot = vec![A::zero(); 3]; + one_hot[self.class as usize] = A::one(); + let one_hot = one_hot + .iter() + .map(|x| RankedDifferentiable::of_scalar(Scalar::make(x.clone()))) + .collect(); + ( + RankedDifferentiable::of_vector(vec), + RankedDifferentiable::of_vector(one_hot), + ) + } +} #[cfg(test)] mod test { - use crate::iris::import; + use crate::iris::{import, Iris, IrisType}; + + const EXPECTED_FIRST: Iris = Iris { + class: IrisType::Setosa, + petal_length: 5.1, + petal_width: 3.5, + sepal_length: 1.4, + sepal_width: 0.2, + }; #[test] fn first_element() { let irises = import(); - assert_eq!(irises[0], crate::iris::EXPECTED_FIRST); + assert_eq!(irises[0], EXPECTED_FIRST); } } diff --git a/little_learner_app/src/main.rs b/little_learner_app/src/main.rs index ac779d9..b917512 100644 --- a/little_learner_app/src/main.rs +++ b/little_learner_app/src/main.rs @@ -2,6 +2,7 @@ #![feature(generic_const_exprs)] use crate::rms_example::rms_example; +use little_learner::auto_diff::RankedDifferentiable; mod iris; mod rms_example; @@ -9,6 +10,14 @@ mod rms_example; fn main() { rms_example(); - let irises = iris::import(); - assert_eq!(irises[0], crate::iris::EXPECTED_FIRST); + let irises = iris::import::(); + let mut xs = Vec::with_capacity(irises.len()); + let mut ys = Vec::with_capacity(irises.len()); + for iris in irises { + let (x, y) = iris.one_hot(); + xs.push(x); + ys.push(y); + } + let _xs = RankedDifferentiable::of_vector(xs); + let _ys = RankedDifferentiable::of_vector(ys); }