Iris data (#29)
This commit is contained in:
@@ -10,3 +10,4 @@ immutable-chunkmap = "1.0.5"
|
||||
ordered-float = "3.6.0"
|
||||
little_learner = { path = "../little_learner" }
|
||||
rand = "0.8.5"
|
||||
csv = "1.2.2"
|
||||
|
151
little_learner_app/src/iris.csv
Executable file
151
little_learner_app/src/iris.csv
Executable file
@@ -0,0 +1,151 @@
|
||||
5.1,3.5,1.4,0.2,Iris-setosa
|
||||
4.9,3.0,1.4,0.2,Iris-setosa
|
||||
4.7,3.2,1.3,0.2,Iris-setosa
|
||||
4.6,3.1,1.5,0.2,Iris-setosa
|
||||
5.0,3.6,1.4,0.2,Iris-setosa
|
||||
5.4,3.9,1.7,0.4,Iris-setosa
|
||||
4.6,3.4,1.4,0.3,Iris-setosa
|
||||
5.0,3.4,1.5,0.2,Iris-setosa
|
||||
4.4,2.9,1.4,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
5.4,3.7,1.5,0.2,Iris-setosa
|
||||
4.8,3.4,1.6,0.2,Iris-setosa
|
||||
4.8,3.0,1.4,0.1,Iris-setosa
|
||||
4.3,3.0,1.1,0.1,Iris-setosa
|
||||
5.8,4.0,1.2,0.2,Iris-setosa
|
||||
5.7,4.4,1.5,0.4,Iris-setosa
|
||||
5.4,3.9,1.3,0.4,Iris-setosa
|
||||
5.1,3.5,1.4,0.3,Iris-setosa
|
||||
5.7,3.8,1.7,0.3,Iris-setosa
|
||||
5.1,3.8,1.5,0.3,Iris-setosa
|
||||
5.4,3.4,1.7,0.2,Iris-setosa
|
||||
5.1,3.7,1.5,0.4,Iris-setosa
|
||||
4.6,3.6,1.0,0.2,Iris-setosa
|
||||
5.1,3.3,1.7,0.5,Iris-setosa
|
||||
4.8,3.4,1.9,0.2,Iris-setosa
|
||||
5.0,3.0,1.6,0.2,Iris-setosa
|
||||
5.0,3.4,1.6,0.4,Iris-setosa
|
||||
5.2,3.5,1.5,0.2,Iris-setosa
|
||||
5.2,3.4,1.4,0.2,Iris-setosa
|
||||
4.7,3.2,1.6,0.2,Iris-setosa
|
||||
4.8,3.1,1.6,0.2,Iris-setosa
|
||||
5.4,3.4,1.5,0.4,Iris-setosa
|
||||
5.2,4.1,1.5,0.1,Iris-setosa
|
||||
5.5,4.2,1.4,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
5.0,3.2,1.2,0.2,Iris-setosa
|
||||
5.5,3.5,1.3,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
4.4,3.0,1.3,0.2,Iris-setosa
|
||||
5.1,3.4,1.5,0.2,Iris-setosa
|
||||
5.0,3.5,1.3,0.3,Iris-setosa
|
||||
4.5,2.3,1.3,0.3,Iris-setosa
|
||||
4.4,3.2,1.3,0.2,Iris-setosa
|
||||
5.0,3.5,1.6,0.6,Iris-setosa
|
||||
5.1,3.8,1.9,0.4,Iris-setosa
|
||||
4.8,3.0,1.4,0.3,Iris-setosa
|
||||
5.1,3.8,1.6,0.2,Iris-setosa
|
||||
4.6,3.2,1.4,0.2,Iris-setosa
|
||||
5.3,3.7,1.5,0.2,Iris-setosa
|
||||
5.0,3.3,1.4,0.2,Iris-setosa
|
||||
7.0,3.2,4.7,1.4,Iris-versicolor
|
||||
6.4,3.2,4.5,1.5,Iris-versicolor
|
||||
6.9,3.1,4.9,1.5,Iris-versicolor
|
||||
5.5,2.3,4.0,1.3,Iris-versicolor
|
||||
6.5,2.8,4.6,1.5,Iris-versicolor
|
||||
5.7,2.8,4.5,1.3,Iris-versicolor
|
||||
6.3,3.3,4.7,1.6,Iris-versicolor
|
||||
4.9,2.4,3.3,1.0,Iris-versicolor
|
||||
6.6,2.9,4.6,1.3,Iris-versicolor
|
||||
5.2,2.7,3.9,1.4,Iris-versicolor
|
||||
5.0,2.0,3.5,1.0,Iris-versicolor
|
||||
5.9,3.0,4.2,1.5,Iris-versicolor
|
||||
6.0,2.2,4.0,1.0,Iris-versicolor
|
||||
6.1,2.9,4.7,1.4,Iris-versicolor
|
||||
5.6,2.9,3.6,1.3,Iris-versicolor
|
||||
6.7,3.1,4.4,1.4,Iris-versicolor
|
||||
5.6,3.0,4.5,1.5,Iris-versicolor
|
||||
5.8,2.7,4.1,1.0,Iris-versicolor
|
||||
6.2,2.2,4.5,1.5,Iris-versicolor
|
||||
5.6,2.5,3.9,1.1,Iris-versicolor
|
||||
5.9,3.2,4.8,1.8,Iris-versicolor
|
||||
6.1,2.8,4.0,1.3,Iris-versicolor
|
||||
6.3,2.5,4.9,1.5,Iris-versicolor
|
||||
6.1,2.8,4.7,1.2,Iris-versicolor
|
||||
6.4,2.9,4.3,1.3,Iris-versicolor
|
||||
6.6,3.0,4.4,1.4,Iris-versicolor
|
||||
6.8,2.8,4.8,1.4,Iris-versicolor
|
||||
6.7,3.0,5.0,1.7,Iris-versicolor
|
||||
6.0,2.9,4.5,1.5,Iris-versicolor
|
||||
5.7,2.6,3.5,1.0,Iris-versicolor
|
||||
5.5,2.4,3.8,1.1,Iris-versicolor
|
||||
5.5,2.4,3.7,1.0,Iris-versicolor
|
||||
5.8,2.7,3.9,1.2,Iris-versicolor
|
||||
6.0,2.7,5.1,1.6,Iris-versicolor
|
||||
5.4,3.0,4.5,1.5,Iris-versicolor
|
||||
6.0,3.4,4.5,1.6,Iris-versicolor
|
||||
6.7,3.1,4.7,1.5,Iris-versicolor
|
||||
6.3,2.3,4.4,1.3,Iris-versicolor
|
||||
5.6,3.0,4.1,1.3,Iris-versicolor
|
||||
5.5,2.5,4.0,1.3,Iris-versicolor
|
||||
5.5,2.6,4.4,1.2,Iris-versicolor
|
||||
6.1,3.0,4.6,1.4,Iris-versicolor
|
||||
5.8,2.6,4.0,1.2,Iris-versicolor
|
||||
5.0,2.3,3.3,1.0,Iris-versicolor
|
||||
5.6,2.7,4.2,1.3,Iris-versicolor
|
||||
5.7,3.0,4.2,1.2,Iris-versicolor
|
||||
5.7,2.9,4.2,1.3,Iris-versicolor
|
||||
6.2,2.9,4.3,1.3,Iris-versicolor
|
||||
5.1,2.5,3.0,1.1,Iris-versicolor
|
||||
5.7,2.8,4.1,1.3,Iris-versicolor
|
||||
6.3,3.3,6.0,2.5,Iris-virginica
|
||||
5.8,2.7,5.1,1.9,Iris-virginica
|
||||
7.1,3.0,5.9,2.1,Iris-virginica
|
||||
6.3,2.9,5.6,1.8,Iris-virginica
|
||||
6.5,3.0,5.8,2.2,Iris-virginica
|
||||
7.6,3.0,6.6,2.1,Iris-virginica
|
||||
4.9,2.5,4.5,1.7,Iris-virginica
|
||||
7.3,2.9,6.3,1.8,Iris-virginica
|
||||
6.7,2.5,5.8,1.8,Iris-virginica
|
||||
7.2,3.6,6.1,2.5,Iris-virginica
|
||||
6.5,3.2,5.1,2.0,Iris-virginica
|
||||
6.4,2.7,5.3,1.9,Iris-virginica
|
||||
6.8,3.0,5.5,2.1,Iris-virginica
|
||||
5.7,2.5,5.0,2.0,Iris-virginica
|
||||
5.8,2.8,5.1,2.4,Iris-virginica
|
||||
6.4,3.2,5.3,2.3,Iris-virginica
|
||||
6.5,3.0,5.5,1.8,Iris-virginica
|
||||
7.7,3.8,6.7,2.2,Iris-virginica
|
||||
7.7,2.6,6.9,2.3,Iris-virginica
|
||||
6.0,2.2,5.0,1.5,Iris-virginica
|
||||
6.9,3.2,5.7,2.3,Iris-virginica
|
||||
5.6,2.8,4.9,2.0,Iris-virginica
|
||||
7.7,2.8,6.7,2.0,Iris-virginica
|
||||
6.3,2.7,4.9,1.8,Iris-virginica
|
||||
6.7,3.3,5.7,2.1,Iris-virginica
|
||||
7.2,3.2,6.0,1.8,Iris-virginica
|
||||
6.2,2.8,4.8,1.8,Iris-virginica
|
||||
6.1,3.0,4.9,1.8,Iris-virginica
|
||||
6.4,2.8,5.6,2.1,Iris-virginica
|
||||
7.2,3.0,5.8,1.6,Iris-virginica
|
||||
7.4,2.8,6.1,1.9,Iris-virginica
|
||||
7.9,3.8,6.4,2.0,Iris-virginica
|
||||
6.4,2.8,5.6,2.2,Iris-virginica
|
||||
6.3,2.8,5.1,1.5,Iris-virginica
|
||||
6.1,2.6,5.6,1.4,Iris-virginica
|
||||
7.7,3.0,6.1,2.3,Iris-virginica
|
||||
6.3,3.4,5.6,2.4,Iris-virginica
|
||||
6.4,3.1,5.5,1.8,Iris-virginica
|
||||
6.0,3.0,4.8,1.8,Iris-virginica
|
||||
6.9,3.1,5.4,2.1,Iris-virginica
|
||||
6.7,3.1,5.6,2.4,Iris-virginica
|
||||
6.9,3.1,5.1,2.3,Iris-virginica
|
||||
5.8,2.7,5.1,1.9,Iris-virginica
|
||||
6.8,3.2,5.9,2.3,Iris-virginica
|
||||
6.7,3.3,5.7,2.5,Iris-virginica
|
||||
6.7,3.0,5.2,2.3,Iris-virginica
|
||||
6.3,2.5,5.0,1.9,Iris-virginica
|
||||
6.5,3.0,5.2,2.0,Iris-virginica
|
||||
6.2,3.4,5.4,2.3,Iris-virginica
|
||||
5.9,3.0,5.1,1.8,Iris-virginica
|
||||
|
|
77
little_learner_app/src/iris.rs
Normal file
77
little_learner_app/src/iris.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use csv::ReaderBuilder;
|
||||
use std::io::Cursor;
|
||||
use std::str::FromStr;
|
||||
|
||||
const IRIS_DATA: &str = include_str!("iris.csv");
|
||||
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub enum IrisType {
|
||||
Setosa,
|
||||
Versicolor,
|
||||
Virginica,
|
||||
}
|
||||
|
||||
impl FromStr for IrisType {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Iris-virginica" => Ok(IrisType::Virginica),
|
||||
"Iris-versicolor" => Ok(IrisType::Versicolor),
|
||||
"Iris-setosa" => Ok(IrisType::Setosa),
|
||||
_ => Err(String::from(s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub struct Iris {
|
||||
pub class: IrisType,
|
||||
pub petal_length: f32,
|
||||
pub petal_width: f32,
|
||||
pub sepal_length: f32,
|
||||
pub sepal_width: f32,
|
||||
}
|
||||
|
||||
pub fn import() -> Vec<Iris> {
|
||||
let mut reader = ReaderBuilder::new()
|
||||
.has_headers(false)
|
||||
.from_reader(Cursor::new(IRIS_DATA));
|
||||
let mut output = Vec::new();
|
||||
for record in reader.records() {
|
||||
let record = record.unwrap();
|
||||
let petal_length = f32::from_str(&record[0]).unwrap();
|
||||
let petal_width = f32::from_str(&record[1]).unwrap();
|
||||
let sepal_length = f32::from_str(&record[2]).unwrap();
|
||||
let sepal_width = f32::from_str(&record[3]).unwrap();
|
||||
let class = IrisType::from_str(&record[4]).unwrap();
|
||||
output.push(Iris {
|
||||
class,
|
||||
petal_length,
|
||||
petal_width,
|
||||
sepal_length,
|
||||
sepal_width,
|
||||
});
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) const EXPECTED_FIRST: Iris = Iris {
|
||||
class: IrisType::Setosa,
|
||||
petal_length: 5.1,
|
||||
petal_width: 3.5,
|
||||
sepal_length: 1.4,
|
||||
sepal_width: 0.2,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::iris::import;
|
||||
|
||||
#[test]
|
||||
fn first_element() {
|
||||
let irises = import();
|
||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
||||
}
|
||||
}
|
@@ -1,71 +1,14 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
|
||||
use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||
use crate::rms_example::rms_example;
|
||||
|
||||
use little_learner::gradient_descent::gradient_descent;
|
||||
use little_learner::hyper;
|
||||
use little_learner::loss::predict_plane;
|
||||
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use little_learner::predictor;
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::Zero;
|
||||
use ordered_float::NotNan;
|
||||
|
||||
const PLANE_XS: [[f64; 2]; 6] = [
|
||||
[1.0, 2.05],
|
||||
[1.0, 3.0],
|
||||
[2.0, 2.0],
|
||||
[2.0, 3.91],
|
||||
[3.0, 6.13],
|
||||
[4.0, 8.09],
|
||||
];
|
||||
const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
|
||||
mod iris;
|
||||
mod rms_example;
|
||||
|
||||
fn main() {
|
||||
let beta = NotNan::new(0.9).expect("not nan");
|
||||
let stabilizer = NotNan::new(0.000_000_01).expect("not nan");
|
||||
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
|
||||
.with_stabilizer(stabilizer)
|
||||
.with_beta(beta);
|
||||
rms_example();
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(PLANE_XS);
|
||||
let ys = to_not_nan_1(PLANE_YS);
|
||||
let zero_params = [
|
||||
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
|
||||
.to_unranked(),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&xs,
|
||||
RankedDifferentiableTagged::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
zero_params,
|
||||
predictor::rms(predict_plane),
|
||||
hyper::RmsGradientDescent::to_immutable,
|
||||
)
|
||||
};
|
||||
|
||||
let [theta0, theta1] = iterated;
|
||||
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
let fitted_theta0 = theta0
|
||||
.collect()
|
||||
.iter()
|
||||
.map(|x| x.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.974_645_444_172_085, 1.971_454_922_077_495]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
|
||||
let irises = iris::import();
|
||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
|
65
little_learner_app/src/rms_example.rs
Normal file
65
little_learner_app/src/rms_example.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||
|
||||
use little_learner::gradient_descent::gradient_descent;
|
||||
use little_learner::hyper;
|
||||
use little_learner::loss::predict_plane;
|
||||
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use little_learner::predictor;
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::Zero;
|
||||
use ordered_float::NotNan;
|
||||
|
||||
const PLANE_XS: [[f64; 2]; 6] = [
|
||||
[1.0, 2.05],
|
||||
[1.0, 3.0],
|
||||
[2.0, 2.0],
|
||||
[2.0, 3.91],
|
||||
[3.0, 6.13],
|
||||
[4.0, 8.09],
|
||||
];
|
||||
const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
|
||||
|
||||
pub(crate) fn rms_example() {
|
||||
let beta = NotNan::new(0.9).expect("not nan");
|
||||
let stabilizer = NotNan::new(0.000_000_01).expect("not nan");
|
||||
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
|
||||
.with_stabilizer(stabilizer)
|
||||
.with_beta(beta);
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(PLANE_XS);
|
||||
let ys = to_not_nan_1(PLANE_YS);
|
||||
let zero_params = [
|
||||
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
|
||||
.to_unranked(),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&xs,
|
||||
RankedDifferentiableTagged::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
zero_params,
|
||||
predictor::rms(predict_plane),
|
||||
hyper::RmsGradientDescent::to_immutable,
|
||||
)
|
||||
};
|
||||
|
||||
let [theta0, theta1] = iterated;
|
||||
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
let fitted_theta0 = theta0
|
||||
.collect()
|
||||
.iter()
|
||||
.map(|x| x.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.974_645_444_172_085, 1.971_454_922_077_495]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
|
||||
}
|
Reference in New Issue
Block a user