One-hot encoding (#30)
This commit is contained in:
@@ -1,14 +1,18 @@
|
|||||||
use csv::ReaderBuilder;
|
use csv::ReaderBuilder;
|
||||||
|
use little_learner::auto_diff::RankedDifferentiable;
|
||||||
|
use little_learner::scalar::Scalar;
|
||||||
|
use little_learner::traits::{One, Zero};
|
||||||
|
use std::fmt::Debug;
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
|
||||||
const IRIS_DATA: &str = include_str!("iris.csv");
|
const IRIS_DATA: &str = include_str!("iris.csv");
|
||||||
|
|
||||||
#[derive(Eq, PartialEq, Debug)]
|
#[derive(Eq, PartialEq, Debug, Clone, Copy)]
|
||||||
pub enum IrisType {
|
pub enum IrisType {
|
||||||
Setosa,
|
Setosa = 0,
|
||||||
Versicolor,
|
Versicolor = 1,
|
||||||
Virginica,
|
Virginica = 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl FromStr for IrisType {
|
impl FromStr for IrisType {
|
||||||
@@ -25,25 +29,29 @@ impl FromStr for IrisType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(PartialEq, Debug)]
|
#[derive(PartialEq, Debug)]
|
||||||
pub struct Iris {
|
pub struct Iris<A> {
|
||||||
pub class: IrisType,
|
pub class: IrisType,
|
||||||
pub petal_length: f32,
|
pub petal_length: A,
|
||||||
pub petal_width: f32,
|
pub petal_width: A,
|
||||||
pub sepal_length: f32,
|
pub sepal_length: A,
|
||||||
pub sepal_width: f32,
|
pub sepal_width: A,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn import() -> Vec<Iris> {
|
pub fn import<A, B>() -> Vec<Iris<A>>
|
||||||
|
where
|
||||||
|
A: FromStr<Err = B>,
|
||||||
|
B: Debug,
|
||||||
|
{
|
||||||
let mut reader = ReaderBuilder::new()
|
let mut reader = ReaderBuilder::new()
|
||||||
.has_headers(false)
|
.has_headers(false)
|
||||||
.from_reader(Cursor::new(IRIS_DATA));
|
.from_reader(Cursor::new(IRIS_DATA));
|
||||||
let mut output = Vec::new();
|
let mut output = Vec::new();
|
||||||
for record in reader.records() {
|
for record in reader.records() {
|
||||||
let record = record.unwrap();
|
let record = record.unwrap();
|
||||||
let petal_length = f32::from_str(&record[0]).unwrap();
|
let petal_length = A::from_str(&record[0]).unwrap();
|
||||||
let petal_width = f32::from_str(&record[1]).unwrap();
|
let petal_width = A::from_str(&record[1]).unwrap();
|
||||||
let sepal_length = f32::from_str(&record[2]).unwrap();
|
let sepal_length = A::from_str(&record[2]).unwrap();
|
||||||
let sepal_width = f32::from_str(&record[3]).unwrap();
|
let sepal_width = A::from_str(&record[3]).unwrap();
|
||||||
let class = IrisType::from_str(&record[4]).unwrap();
|
let class = IrisType::from_str(&record[4]).unwrap();
|
||||||
output.push(Iris {
|
output.push(Iris {
|
||||||
class,
|
class,
|
||||||
@@ -57,21 +65,46 @@ pub fn import() -> Vec<Iris> {
|
|||||||
output
|
output
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) const EXPECTED_FIRST: Iris = Iris {
|
impl<A> Iris<A> {
|
||||||
class: IrisType::Setosa,
|
pub fn one_hot(&self) -> (RankedDifferentiable<A, 1>, RankedDifferentiable<A, 1>)
|
||||||
petal_length: 5.1,
|
where
|
||||||
petal_width: 3.5,
|
A: Clone + Zero + One,
|
||||||
sepal_length: 1.4,
|
{
|
||||||
sepal_width: 0.2,
|
let vec = vec![
|
||||||
};
|
RankedDifferentiable::of_scalar(Scalar::make(self.petal_length.clone())),
|
||||||
|
RankedDifferentiable::of_scalar(Scalar::make(self.petal_width.clone())),
|
||||||
|
RankedDifferentiable::of_scalar(Scalar::make(self.sepal_length.clone())),
|
||||||
|
RankedDifferentiable::of_scalar(Scalar::make(self.sepal_width.clone())),
|
||||||
|
];
|
||||||
|
|
||||||
|
let mut one_hot = vec![A::zero(); 3];
|
||||||
|
one_hot[self.class as usize] = A::one();
|
||||||
|
let one_hot = one_hot
|
||||||
|
.iter()
|
||||||
|
.map(|x| RankedDifferentiable::of_scalar(Scalar::make(x.clone())))
|
||||||
|
.collect();
|
||||||
|
(
|
||||||
|
RankedDifferentiable::of_vector(vec),
|
||||||
|
RankedDifferentiable::of_vector(one_hot),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use crate::iris::import;
|
use crate::iris::{import, Iris, IrisType};
|
||||||
|
|
||||||
|
const EXPECTED_FIRST: Iris<f32> = Iris {
|
||||||
|
class: IrisType::Setosa,
|
||||||
|
petal_length: 5.1,
|
||||||
|
petal_width: 3.5,
|
||||||
|
sepal_length: 1.4,
|
||||||
|
sepal_width: 0.2,
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn first_element() {
|
fn first_element() {
|
||||||
let irises = import();
|
let irises = import();
|
||||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
assert_eq!(irises[0], EXPECTED_FIRST);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -2,6 +2,7 @@
|
|||||||
#![feature(generic_const_exprs)]
|
#![feature(generic_const_exprs)]
|
||||||
|
|
||||||
use crate::rms_example::rms_example;
|
use crate::rms_example::rms_example;
|
||||||
|
use little_learner::auto_diff::RankedDifferentiable;
|
||||||
|
|
||||||
mod iris;
|
mod iris;
|
||||||
mod rms_example;
|
mod rms_example;
|
||||||
@@ -9,6 +10,14 @@ mod rms_example;
|
|||||||
fn main() {
|
fn main() {
|
||||||
rms_example();
|
rms_example();
|
||||||
|
|
||||||
let irises = iris::import();
|
let irises = iris::import::<f64, _>();
|
||||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
let mut xs = Vec::with_capacity(irises.len());
|
||||||
|
let mut ys = Vec::with_capacity(irises.len());
|
||||||
|
for iris in irises {
|
||||||
|
let (x, y) = iris.one_hot();
|
||||||
|
xs.push(x);
|
||||||
|
ys.push(y);
|
||||||
|
}
|
||||||
|
let _xs = RankedDifferentiable::of_vector(xs);
|
||||||
|
let _ys = RankedDifferentiable::of_vector(ys);
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user