One-hot encoding (#30)
This commit is contained in:
@@ -1,14 +1,18 @@
|
||||
use csv::ReaderBuilder;
|
||||
use little_learner::auto_diff::RankedDifferentiable;
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::{One, Zero};
|
||||
use std::fmt::Debug;
|
||||
use std::io::Cursor;
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Text of `iris.csv`, embedded at compile time via `include_str!`.
const IRIS_DATA: &str = include_str!("iris.csv");
|
||||
|
||||
/// Class label attached to one iris record.
///
/// The explicit discriminants are load-bearing: a label is cast with `as usize` to pick
/// the hot slot of a three-element one-hot vector, so the values must stay contiguous
/// and start at zero. `Copy` is derived so that cast can be performed through a shared
/// reference to the owning record.
#[derive(Eq, PartialEq, Debug, Clone, Copy)]
pub enum IrisType {
    Setosa = 0,
    Versicolor = 1,
    Virginica = 2,
}
|
||||
|
||||
impl FromStr for IrisType {
|
||||
@@ -25,25 +29,29 @@ impl FromStr for IrisType {
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub struct Iris {
|
||||
pub struct Iris<A> {
|
||||
pub class: IrisType,
|
||||
pub petal_length: f32,
|
||||
pub petal_width: f32,
|
||||
pub sepal_length: f32,
|
||||
pub sepal_width: f32,
|
||||
pub petal_length: A,
|
||||
pub petal_width: A,
|
||||
pub sepal_length: A,
|
||||
pub sepal_width: A,
|
||||
}
|
||||
|
||||
pub fn import() -> Vec<Iris> {
|
||||
pub fn import<A, B>() -> Vec<Iris<A>>
|
||||
where
|
||||
A: FromStr<Err = B>,
|
||||
B: Debug,
|
||||
{
|
||||
let mut reader = ReaderBuilder::new()
|
||||
.has_headers(false)
|
||||
.from_reader(Cursor::new(IRIS_DATA));
|
||||
let mut output = Vec::new();
|
||||
for record in reader.records() {
|
||||
let record = record.unwrap();
|
||||
let petal_length = f32::from_str(&record[0]).unwrap();
|
||||
let petal_width = f32::from_str(&record[1]).unwrap();
|
||||
let sepal_length = f32::from_str(&record[2]).unwrap();
|
||||
let sepal_width = f32::from_str(&record[3]).unwrap();
|
||||
let petal_length = A::from_str(&record[0]).unwrap();
|
||||
let petal_width = A::from_str(&record[1]).unwrap();
|
||||
let sepal_length = A::from_str(&record[2]).unwrap();
|
||||
let sepal_width = A::from_str(&record[3]).unwrap();
|
||||
let class = IrisType::from_str(&record[4]).unwrap();
|
||||
output.push(Iris {
|
||||
class,
|
||||
@@ -57,21 +65,46 @@ pub fn import() -> Vec<Iris> {
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) const EXPECTED_FIRST: Iris = Iris {
|
||||
class: IrisType::Setosa,
|
||||
petal_length: 5.1,
|
||||
petal_width: 3.5,
|
||||
sepal_length: 1.4,
|
||||
sepal_width: 0.2,
|
||||
};
|
||||
impl<A> Iris<A> {
|
||||
pub fn one_hot(&self) -> (RankedDifferentiable<A, 1>, RankedDifferentiable<A, 1>)
|
||||
where
|
||||
A: Clone + Zero + One,
|
||||
{
|
||||
let vec = vec![
|
||||
RankedDifferentiable::of_scalar(Scalar::make(self.petal_length.clone())),
|
||||
RankedDifferentiable::of_scalar(Scalar::make(self.petal_width.clone())),
|
||||
RankedDifferentiable::of_scalar(Scalar::make(self.sepal_length.clone())),
|
||||
RankedDifferentiable::of_scalar(Scalar::make(self.sepal_width.clone())),
|
||||
];
|
||||
|
||||
let mut one_hot = vec![A::zero(); 3];
|
||||
one_hot[self.class as usize] = A::one();
|
||||
let one_hot = one_hot
|
||||
.iter()
|
||||
.map(|x| RankedDifferentiable::of_scalar(Scalar::make(x.clone())))
|
||||
.collect();
|
||||
(
|
||||
RankedDifferentiable::of_vector(vec),
|
||||
RankedDifferentiable::of_vector(one_hot),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test {
    use crate::iris::{import, Iris, IrisType};

    /// Record the importer is expected to produce first, with `f32` measurements.
    const EXPECTED_FIRST: Iris<f32> = Iris {
        class: IrisType::Setosa,
        petal_length: 5.1,
        petal_width: 3.5,
        sepal_length: 1.4,
        sepal_width: 0.2,
    };

    /// The first imported record must equal the fixture.
    #[test]
    fn first_element() {
        // The measurement type is spelled out rather than inferred from the comparison.
        let irises = import::<f32, _>();
        assert_eq!(irises[0], EXPECTED_FIRST);
    }
}
|
||||
|
@@ -2,6 +2,7 @@
|
||||
#![feature(generic_const_exprs)]
|
||||
|
||||
use crate::rms_example::rms_example;
|
||||
use little_learner::auto_diff::RankedDifferentiable;
|
||||
|
||||
mod iris;
|
||||
mod rms_example;
|
||||
@@ -9,6 +10,14 @@ mod rms_example;
|
||||
fn main() {
|
||||
rms_example();
|
||||
|
||||
let irises = iris::import();
|
||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
||||
let irises = iris::import::<f64, _>();
|
||||
let mut xs = Vec::with_capacity(irises.len());
|
||||
let mut ys = Vec::with_capacity(irises.len());
|
||||
for iris in irises {
|
||||
let (x, y) = iris.one_hot();
|
||||
xs.push(x);
|
||||
ys.push(y);
|
||||
}
|
||||
let _xs = RankedDifferentiable::of_vector(xs);
|
||||
let _ys = RankedDifferentiable::of_vector(ys);
|
||||
}
|
||||
|
Reference in New Issue
Block a user