One-hot encoding (#30)

This commit is contained in:
Patrick Stevens
2023-06-17 22:14:37 +01:00
committed by GitHub
parent fd55cd1c5f
commit bdb5d8e192
2 changed files with 67 additions and 25 deletions

View File

@@ -1,14 +1,18 @@
use csv::ReaderBuilder; use csv::ReaderBuilder;
use little_learner::auto_diff::RankedDifferentiable;
use little_learner::scalar::Scalar;
use little_learner::traits::{One, Zero};
use std::fmt::Debug;
use std::io::Cursor; use std::io::Cursor;
use std::str::FromStr; use std::str::FromStr;
const IRIS_DATA: &str = include_str!("iris.csv"); const IRIS_DATA: &str = include_str!("iris.csv");
#[derive(Eq, PartialEq, Debug)] #[derive(Eq, PartialEq, Debug, Clone, Copy)]
pub enum IrisType { pub enum IrisType {
Setosa, Setosa = 0,
Versicolor, Versicolor = 1,
Virginica, Virginica = 2,
} }
impl FromStr for IrisType { impl FromStr for IrisType {
@@ -25,25 +29,29 @@ impl FromStr for IrisType {
} }
#[derive(PartialEq, Debug)] #[derive(PartialEq, Debug)]
pub struct Iris { pub struct Iris<A> {
pub class: IrisType, pub class: IrisType,
pub petal_length: f32, pub petal_length: A,
pub petal_width: f32, pub petal_width: A,
pub sepal_length: f32, pub sepal_length: A,
pub sepal_width: f32, pub sepal_width: A,
} }
pub fn import() -> Vec<Iris> { pub fn import<A, B>() -> Vec<Iris<A>>
where
A: FromStr<Err = B>,
B: Debug,
{
let mut reader = ReaderBuilder::new() let mut reader = ReaderBuilder::new()
.has_headers(false) .has_headers(false)
.from_reader(Cursor::new(IRIS_DATA)); .from_reader(Cursor::new(IRIS_DATA));
let mut output = Vec::new(); let mut output = Vec::new();
for record in reader.records() { for record in reader.records() {
let record = record.unwrap(); let record = record.unwrap();
let petal_length = f32::from_str(&record[0]).unwrap(); let petal_length = A::from_str(&record[0]).unwrap();
let petal_width = f32::from_str(&record[1]).unwrap(); let petal_width = A::from_str(&record[1]).unwrap();
let sepal_length = f32::from_str(&record[2]).unwrap(); let sepal_length = A::from_str(&record[2]).unwrap();
let sepal_width = f32::from_str(&record[3]).unwrap(); let sepal_width = A::from_str(&record[3]).unwrap();
let class = IrisType::from_str(&record[4]).unwrap(); let class = IrisType::from_str(&record[4]).unwrap();
output.push(Iris { output.push(Iris {
class, class,
@@ -57,21 +65,46 @@ pub fn import() -> Vec<Iris> {
output output
} }
pub(crate) const EXPECTED_FIRST: Iris = Iris { impl<A> Iris<A> {
class: IrisType::Setosa, pub fn one_hot(&self) -> (RankedDifferentiable<A, 1>, RankedDifferentiable<A, 1>)
petal_length: 5.1, where
petal_width: 3.5, A: Clone + Zero + One,
sepal_length: 1.4, {
sepal_width: 0.2, let vec = vec![
}; RankedDifferentiable::of_scalar(Scalar::make(self.petal_length.clone())),
RankedDifferentiable::of_scalar(Scalar::make(self.petal_width.clone())),
RankedDifferentiable::of_scalar(Scalar::make(self.sepal_length.clone())),
RankedDifferentiable::of_scalar(Scalar::make(self.sepal_width.clone())),
];
let mut one_hot = vec![A::zero(); 3];
one_hot[self.class as usize] = A::one();
let one_hot = one_hot
.iter()
.map(|x| RankedDifferentiable::of_scalar(Scalar::make(x.clone())))
.collect();
(
RankedDifferentiable::of_vector(vec),
RankedDifferentiable::of_vector(one_hot),
)
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::iris::import; use crate::iris::{import, Iris, IrisType};
const EXPECTED_FIRST: Iris<f32> = Iris {
class: IrisType::Setosa,
petal_length: 5.1,
petal_width: 3.5,
sepal_length: 1.4,
sepal_width: 0.2,
};
#[test] #[test]
fn first_element() { fn first_element() {
let irises = import(); let irises = import();
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST); assert_eq!(irises[0], EXPECTED_FIRST);
} }
} }

View File

@@ -2,6 +2,7 @@
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
use crate::rms_example::rms_example; use crate::rms_example::rms_example;
use little_learner::auto_diff::RankedDifferentiable;
mod iris; mod iris;
mod rms_example; mod rms_example;
@@ -9,6 +10,14 @@ mod rms_example;
fn main() { fn main() {
rms_example(); rms_example();
let irises = iris::import(); let irises = iris::import::<f64, _>();
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST); let mut xs = Vec::with_capacity(irises.len());
let mut ys = Vec::with_capacity(irises.len());
for iris in irises {
let (x, y) = iris.one_hot();
xs.push(x);
ys.push(y);
}
let _xs = RankedDifferentiable::of_vector(xs);
let _ys = RankedDifferentiable::of_vector(ys);
} }