From fd55cd1c5f36616a36a4c9382b9e709460014ade Mon Sep 17 00:00:00 2001 From: Patrick Stevens Date: Sat, 17 Jun 2023 19:03:01 +0100 Subject: [PATCH] Iris data (#29) --- Cargo.lock | 46 ++++++++ little_learner/src/block.rs | 50 ++++++++ little_learner/src/ext.rs | 2 +- little_learner/src/gradient_descent.rs | 9 +- little_learner/src/lib.rs | 1 + little_learner/src/smooth.rs | 2 +- little_learner_app/Cargo.toml | 1 + little_learner_app/src/iris.csv | 151 +++++++++++++++++++++++++ little_learner_app/src/iris.rs | 77 +++++++++++++ little_learner_app/src/main.rs | 69 +---------- little_learner_app/src/rms_example.rs | 65 +++++++++++ 11 files changed, 405 insertions(+), 68 deletions(-) create mode 100644 little_learner/src/block.rs create mode 100755 little_learner_app/src/iris.csv create mode 100644 little_learner_app/src/iris.rs create mode 100644 little_learner_app/src/rms_example.rs diff --git a/Cargo.lock b/Cargo.lock index 685a390..923f5c7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,27 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "csv" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90" +dependencies = [ + "memchr", +] + [[package]] name = "funty" version = "2.0.0" @@ -60,6 +81,12 @@ dependencies = [ "packed_struct_codegen", ] +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + [[package]] name = "libc" version 
= "0.2.142" @@ -79,12 +106,19 @@ dependencies = [ name = "little_learner_app" version = "0.1.0" dependencies = [ + "csv", "immutable-chunkmap", "little_learner", "ordered-float", "rand", ] +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + [[package]] name = "num-traits" version = "0.2.15" @@ -184,6 +218,18 @@ dependencies = [ "getrandom", ] +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "serde" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" + [[package]] name = "syn" version = "1.0.109" diff --git a/little_learner/src/block.rs b/little_learner/src/block.rs new file mode 100644 index 0000000..576fccd --- /dev/null +++ b/little_learner/src/block.rs @@ -0,0 +1,50 @@ +use crate::auto_diff::{Differentiable, RankedDifferentiableTagged}; +use crate::ext::relu; +use crate::traits::NumLike; + +pub struct Block<F, const N: usize> { + f: F, + ranks: [usize; N], +} + +pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>( + b1: Block<F, N>, + b2: Block<G, M>, + j: usize, +) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }> +where + F: FnOnce(A, &'a [T]) -> B, + G: FnOnce(B, &'a [T]) -> C, + T: 'a, +{ + let mut ranks = [0usize; N + M]; + ranks[..N].copy_from_slice(&b1.ranks[..N]); // slice the destination too: copy_from_slice panics unless both lengths match + ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]); + Block { + f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]), + ranks, + } +} + +pub fn dense<'a, 'b, A, Tag>( + input_len: usize, + neuron_count: usize, +) -> Block< + impl FnOnce(&'a RankedDifferentiableTagged, &'b [Differentiable]) -> Differentiable, + 2, +> +where + Tag: Clone, + A: NumLike + PartialOrd + Default, +{ + Block { + f: |t, theta: &'b
[Differentiable]| -> Differentiable { + relu( + t, + &(theta[0].clone().attach_rank().unwrap()), + &(theta[1].clone().attach_rank().unwrap()), + ) + }, + ranks: [input_len, neuron_count], + } +} diff --git a/little_learner/src/ext.rs b/little_learner/src/ext.rs index 763ed9b..b231be7 100644 --- a/little_learner/src/ext.rs +++ b/little_learner/src/ext.rs @@ -430,7 +430,7 @@ mod tests { crate::decider::relu( &t, &RankedDifferentiable::of_slice(weights), - Scalar::make(bias.clone()), + Scalar::make(*bias), ) .clone_real_part() .into_inner(), diff --git a/little_learner/src/gradient_descent.rs b/little_learner/src/gradient_descent.rs index 5b3fae5..23bb89a 100644 --- a/little_learner/src/gradient_descent.rs +++ b/little_learner/src/gradient_descent.rs @@ -404,8 +404,11 @@ mod tests { .map(|x| x.into_inner()) .collect::>(); let fitted_theta1 = theta1.to_scalar().real_part().into_inner(); - assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]); - assert_eq!(fitted_theta1, 6.1645790482740361); + assert_eq!( + fitted_theta0, + [3.974_645_444_172_085, 1.971_454_922_077_495] + ); + assert_eq!(fitted_theta1, 6.164_579_048_274_036); } #[test] @@ -452,7 +455,7 @@ mod tests { let fitted_theta1 = theta1.to_scalar().real_part().into_inner(); assert_eq!( fitted_theta0, - [3.980_262_420_345_729_5, 1.977_071_898_301_443_9] + [3.980_262_420_345_729_5, 1.977_071_898_301_444] ); assert_eq!(fitted_theta1, 6.170_196_024_282_712_5); } diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs index cdea666..615732e 100644 --- a/little_learner/src/lib.rs +++ b/little_learner/src/lib.rs @@ -3,6 +3,7 @@ #![feature(array_methods)] pub mod auto_diff; +pub mod block; pub mod decider; pub mod ext; pub mod gradient_descent; diff --git a/little_learner/src/smooth.rs b/little_learner/src/smooth.rs index 8f25e03..b15e7c6 100644 --- a/little_learner/src/smooth.rs +++ b/little_learner/src/smooth.rs @@ -106,7 +106,7 @@ mod test_smooth { ] .map(|x| hydrate(&x)); - let mut current = 
hydrate(&vec![0.8, 3.1, 2.2]); + let mut current = hydrate(&[0.8, 3.1, 2.2]); let mut output = Vec::with_capacity(inputs.len()); for input in inputs { current = smooth(decay.clone(), &current, &input); diff --git a/little_learner_app/Cargo.toml b/little_learner_app/Cargo.toml index f43683c..33ae039 100644 --- a/little_learner_app/Cargo.toml +++ b/little_learner_app/Cargo.toml @@ -10,3 +10,4 @@ immutable-chunkmap = "1.0.5" ordered-float = "3.6.0" little_learner = { path = "../little_learner" } rand = "0.8.5" +csv = "1.2.2" diff --git a/little_learner_app/src/iris.csv b/little_learner_app/src/iris.csv new file mode 100755 index 0000000..5c4316c --- /dev/null +++ b/little_learner_app/src/iris.csv @@ -0,0 +1,151 @@ +5.1,3.5,1.4,0.2,Iris-setosa +4.9,3.0,1.4,0.2,Iris-setosa +4.7,3.2,1.3,0.2,Iris-setosa +4.6,3.1,1.5,0.2,Iris-setosa +5.0,3.6,1.4,0.2,Iris-setosa +5.4,3.9,1.7,0.4,Iris-setosa +4.6,3.4,1.4,0.3,Iris-setosa +5.0,3.4,1.5,0.2,Iris-setosa +4.4,2.9,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.4,3.7,1.5,0.2,Iris-setosa +4.8,3.4,1.6,0.2,Iris-setosa +4.8,3.0,1.4,0.1,Iris-setosa +4.3,3.0,1.1,0.1,Iris-setosa +5.8,4.0,1.2,0.2,Iris-setosa +5.7,4.4,1.5,0.4,Iris-setosa +5.4,3.9,1.3,0.4,Iris-setosa +5.1,3.5,1.4,0.3,Iris-setosa +5.7,3.8,1.7,0.3,Iris-setosa +5.1,3.8,1.5,0.3,Iris-setosa +5.4,3.4,1.7,0.2,Iris-setosa +5.1,3.7,1.5,0.4,Iris-setosa +4.6,3.6,1.0,0.2,Iris-setosa +5.1,3.3,1.7,0.5,Iris-setosa +4.8,3.4,1.9,0.2,Iris-setosa +5.0,3.0,1.6,0.2,Iris-setosa +5.0,3.4,1.6,0.4,Iris-setosa +5.2,3.5,1.5,0.2,Iris-setosa +5.2,3.4,1.4,0.2,Iris-setosa +4.7,3.2,1.6,0.2,Iris-setosa +4.8,3.1,1.6,0.2,Iris-setosa +5.4,3.4,1.5,0.4,Iris-setosa +5.2,4.1,1.5,0.1,Iris-setosa +5.5,4.2,1.4,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +5.0,3.2,1.2,0.2,Iris-setosa +5.5,3.5,1.3,0.2,Iris-setosa +4.9,3.1,1.5,0.1,Iris-setosa +4.4,3.0,1.3,0.2,Iris-setosa +5.1,3.4,1.5,0.2,Iris-setosa +5.0,3.5,1.3,0.3,Iris-setosa +4.5,2.3,1.3,0.3,Iris-setosa +4.4,3.2,1.3,0.2,Iris-setosa +5.0,3.5,1.6,0.6,Iris-setosa
+5.1,3.8,1.9,0.4,Iris-setosa +4.8,3.0,1.4,0.3,Iris-setosa +5.1,3.8,1.6,0.2,Iris-setosa +4.6,3.2,1.4,0.2,Iris-setosa +5.3,3.7,1.5,0.2,Iris-setosa +5.0,3.3,1.4,0.2,Iris-setosa +7.0,3.2,4.7,1.4,Iris-versicolor +6.4,3.2,4.5,1.5,Iris-versicolor +6.9,3.1,4.9,1.5,Iris-versicolor +5.5,2.3,4.0,1.3,Iris-versicolor +6.5,2.8,4.6,1.5,Iris-versicolor +5.7,2.8,4.5,1.3,Iris-versicolor +6.3,3.3,4.7,1.6,Iris-versicolor +4.9,2.4,3.3,1.0,Iris-versicolor +6.6,2.9,4.6,1.3,Iris-versicolor +5.2,2.7,3.9,1.4,Iris-versicolor +5.0,2.0,3.5,1.0,Iris-versicolor +5.9,3.0,4.2,1.5,Iris-versicolor +6.0,2.2,4.0,1.0,Iris-versicolor +6.1,2.9,4.7,1.4,Iris-versicolor +5.6,2.9,3.6,1.3,Iris-versicolor +6.7,3.1,4.4,1.4,Iris-versicolor +5.6,3.0,4.5,1.5,Iris-versicolor +5.8,2.7,4.1,1.0,Iris-versicolor +6.2,2.2,4.5,1.5,Iris-versicolor +5.6,2.5,3.9,1.1,Iris-versicolor +5.9,3.2,4.8,1.8,Iris-versicolor +6.1,2.8,4.0,1.3,Iris-versicolor +6.3,2.5,4.9,1.5,Iris-versicolor +6.1,2.8,4.7,1.2,Iris-versicolor +6.4,2.9,4.3,1.3,Iris-versicolor +6.6,3.0,4.4,1.4,Iris-versicolor +6.8,2.8,4.8,1.4,Iris-versicolor +6.7,3.0,5.0,1.7,Iris-versicolor +6.0,2.9,4.5,1.5,Iris-versicolor +5.7,2.6,3.5,1.0,Iris-versicolor +5.5,2.4,3.8,1.1,Iris-versicolor +5.5,2.4,3.7,1.0,Iris-versicolor +5.8,2.7,3.9,1.2,Iris-versicolor +6.0,2.7,5.1,1.6,Iris-versicolor +5.4,3.0,4.5,1.5,Iris-versicolor +6.0,3.4,4.5,1.6,Iris-versicolor +6.7,3.1,4.7,1.5,Iris-versicolor +6.3,2.3,4.4,1.3,Iris-versicolor +5.6,3.0,4.1,1.3,Iris-versicolor +5.5,2.5,4.0,1.3,Iris-versicolor +5.5,2.6,4.4,1.2,Iris-versicolor +6.1,3.0,4.6,1.4,Iris-versicolor +5.8,2.6,4.0,1.2,Iris-versicolor +5.0,2.3,3.3,1.0,Iris-versicolor +5.6,2.7,4.2,1.3,Iris-versicolor +5.7,3.0,4.2,1.2,Iris-versicolor +5.7,2.9,4.2,1.3,Iris-versicolor +6.2,2.9,4.3,1.3,Iris-versicolor +5.1,2.5,3.0,1.1,Iris-versicolor +5.7,2.8,4.1,1.3,Iris-versicolor +6.3,3.3,6.0,2.5,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +7.1,3.0,5.9,2.1,Iris-virginica +6.3,2.9,5.6,1.8,Iris-virginica +6.5,3.0,5.8,2.2,Iris-virginica 
+7.6,3.0,6.6,2.1,Iris-virginica +4.9,2.5,4.5,1.7,Iris-virginica +7.3,2.9,6.3,1.8,Iris-virginica +6.7,2.5,5.8,1.8,Iris-virginica +7.2,3.6,6.1,2.5,Iris-virginica +6.5,3.2,5.1,2.0,Iris-virginica +6.4,2.7,5.3,1.9,Iris-virginica +6.8,3.0,5.5,2.1,Iris-virginica +5.7,2.5,5.0,2.0,Iris-virginica +5.8,2.8,5.1,2.4,Iris-virginica +6.4,3.2,5.3,2.3,Iris-virginica +6.5,3.0,5.5,1.8,Iris-virginica +7.7,3.8,6.7,2.2,Iris-virginica +7.7,2.6,6.9,2.3,Iris-virginica +6.0,2.2,5.0,1.5,Iris-virginica +6.9,3.2,5.7,2.3,Iris-virginica +5.6,2.8,4.9,2.0,Iris-virginica +7.7,2.8,6.7,2.0,Iris-virginica +6.3,2.7,4.9,1.8,Iris-virginica +6.7,3.3,5.7,2.1,Iris-virginica +7.2,3.2,6.0,1.8,Iris-virginica +6.2,2.8,4.8,1.8,Iris-virginica +6.1,3.0,4.9,1.8,Iris-virginica +6.4,2.8,5.6,2.1,Iris-virginica +7.2,3.0,5.8,1.6,Iris-virginica +7.4,2.8,6.1,1.9,Iris-virginica +7.9,3.8,6.4,2.0,Iris-virginica +6.4,2.8,5.6,2.2,Iris-virginica +6.3,2.8,5.1,1.5,Iris-virginica +6.1,2.6,5.6,1.4,Iris-virginica +7.7,3.0,6.1,2.3,Iris-virginica +6.3,3.4,5.6,2.4,Iris-virginica +6.4,3.1,5.5,1.8,Iris-virginica +6.0,3.0,4.8,1.8,Iris-virginica +6.9,3.1,5.4,2.1,Iris-virginica +6.7,3.1,5.6,2.4,Iris-virginica +6.9,3.1,5.1,2.3,Iris-virginica +5.8,2.7,5.1,1.9,Iris-virginica +6.8,3.2,5.9,2.3,Iris-virginica +6.7,3.3,5.7,2.5,Iris-virginica +6.7,3.0,5.2,2.3,Iris-virginica +6.3,2.5,5.0,1.9,Iris-virginica +6.5,3.0,5.2,2.0,Iris-virginica +6.2,3.4,5.4,2.3,Iris-virginica +5.9,3.0,5.1,1.8,Iris-virginica + diff --git a/little_learner_app/src/iris.rs b/little_learner_app/src/iris.rs new file mode 100644 index 0000000..449a9cb --- /dev/null +++ b/little_learner_app/src/iris.rs @@ -0,0 +1,77 @@ +use csv::ReaderBuilder; +use std::io::Cursor; +use std::str::FromStr; + +const IRIS_DATA: &str = include_str!("iris.csv"); + +#[derive(Eq, PartialEq, Debug)] +pub enum IrisType { + Setosa, + Versicolor, + Virginica, +} + +impl FromStr for IrisType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "Iris-virginica" => Ok(IrisType::Virginica), 
+ "Iris-versicolor" => Ok(IrisType::Versicolor), + "Iris-setosa" => Ok(IrisType::Setosa), + _ => Err(String::from(s)), + } + } +} + +#[derive(PartialEq, Debug)] +pub struct Iris { + pub class: IrisType, + pub petal_length: f32, + pub petal_width: f32, + pub sepal_length: f32, + pub sepal_width: f32, +} + +pub fn import() -> Vec<Iris> { + let mut reader = ReaderBuilder::new() + .has_headers(false) + .from_reader(Cursor::new(IRIS_DATA)); + let mut output = Vec::new(); + for record in reader.records() { // iris.csv columns: sepal length, sepal width, petal length, petal width, class + let record = record.unwrap(); + let sepal_length = f32::from_str(&record[0]).unwrap(); + let sepal_width = f32::from_str(&record[1]).unwrap(); + let petal_length = f32::from_str(&record[2]).unwrap(); + let petal_width = f32::from_str(&record[3]).unwrap(); + let class = IrisType::from_str(&record[4]).unwrap(); + output.push(Iris { + class, + petal_length, + petal_width, + sepal_length, + sepal_width, + }); + } + + output +} + +pub(crate) const EXPECTED_FIRST: Iris = Iris { + class: IrisType::Setosa, + petal_length: 1.4, + petal_width: 0.2, + sepal_length: 5.1, + sepal_width: 3.5, +}; + +#[cfg(test)] +mod test { + use crate::iris::import; + + #[test] + fn first_element() { + let irises = import(); + assert_eq!(irises[0], crate::iris::EXPECTED_FIRST); + } +} diff --git a/little_learner_app/src/main.rs b/little_learner_app/src/main.rs index 399fc0a..ac779d9 100644 --- a/little_learner_app/src/main.rs +++ b/little_learner_app/src/main.rs @@ -1,71 +1,14 @@ #![allow(incomplete_features)] #![feature(generic_const_exprs)] -use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged}; +use crate::rms_example::rms_example; -use little_learner::gradient_descent::gradient_descent; -use little_learner::hyper; -use little_learner::loss::predict_plane; -use little_learner::not_nan::{to_not_nan_1, to_not_nan_2}; -use little_learner::predictor; -use little_learner::scalar::Scalar; -use little_learner::traits::Zero; -use ordered_float::NotNan; - -const
PLANE_XS: [[f64; 2]; 6] = [ - [1.0, 2.05], - [1.0, 3.0], - [2.0, 2.0], - [2.0, 3.91], - [3.0, 6.13], - [4.0, 8.09], -]; -const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94]; +mod iris; +mod rms_example; fn main() { - let beta = NotNan::new(0.9).expect("not nan"); - let stabilizer = NotNan::new(0.000_000_01).expect("not nan"); - let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000) - .with_stabilizer(stabilizer) - .with_beta(beta); + rms_example(); - let iterated = { - let xs = to_not_nan_2(PLANE_XS); - let ys = to_not_nan_1(PLANE_YS); - let zero_params = [ - RankedDifferentiable::of_slice(&[NotNan::::zero(), NotNan::::zero()]) - .to_unranked(), - Differentiable::of_scalar(Scalar::zero()), - ]; - - gradient_descent( - hyper, - &xs, - RankedDifferentiableTagged::of_slice_2::<_, 2>, - &ys, - zero_params, - predictor::rms(predict_plane), - hyper::RmsGradientDescent::to_immutable, - ) - }; - - let [theta0, theta1] = iterated; - - let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor"); - let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor"); - - let fitted_theta0 = theta0 - .collect() - .iter() - .map(|x| x.into_inner()) - .collect::>(); - let fitted_theta1 = theta1.to_scalar().real_part().into_inner(); - assert_eq!( - fitted_theta0, - [3.974_645_444_172_085, 1.971_454_922_077_495] - ); - assert_eq!(fitted_theta1, 6.164_579_048_274_036); + let irises = iris::import(); + assert_eq!(irises[0], crate::iris::EXPECTED_FIRST); } - -#[cfg(test)] -mod tests {} diff --git a/little_learner_app/src/rms_example.rs b/little_learner_app/src/rms_example.rs new file mode 100644 index 0000000..ca44dc3 --- /dev/null +++ b/little_learner_app/src/rms_example.rs @@ -0,0 +1,65 @@ +use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged}; + +use little_learner::gradient_descent::gradient_descent; +use little_learner::hyper; +use little_learner::loss::predict_plane; +use 
little_learner::not_nan::{to_not_nan_1, to_not_nan_2}; +use little_learner::predictor; +use little_learner::scalar::Scalar; +use little_learner::traits::Zero; +use ordered_float::NotNan; + +const PLANE_XS: [[f64; 2]; 6] = [ + [1.0, 2.05], + [1.0, 3.0], + [2.0, 2.0], + [2.0, 3.91], + [3.0, 6.13], + [4.0, 8.09], +]; +const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94]; + +pub(crate) fn rms_example() { // runs RMS gradient descent on PLANE_XS / PLANE_YS from zero parameters and asserts the fitted values + let beta = NotNan::new(0.9).expect("not nan"); + let stabilizer = NotNan::new(0.000_000_01).expect("not nan"); + let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000) + .with_stabilizer(stabilizer) + .with_beta(beta); + + let iterated = { + let xs = to_not_nan_2(PLANE_XS); + let ys = to_not_nan_1(PLANE_YS); + let zero_params = [ + RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()]) + .to_unranked(), + Differentiable::of_scalar(Scalar::zero()), + ]; + + gradient_descent( + hyper, + &xs, + RankedDifferentiableTagged::of_slice_2::<_, 2>, + &ys, + zero_params, + predictor::rms(predict_plane), + hyper::RmsGradientDescent::to_immutable, + ) + }; + + let [theta0, theta1] = iterated; + + let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor"); + let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor"); + + let fitted_theta0 = theta0 + .collect() + .iter() + .map(|x| x.into_inner()) + .collect::<Vec<_>>(); + let fitted_theta1 = theta1.to_scalar().real_part().into_inner(); + assert_eq!( + fitted_theta0, + [3.974_645_444_172_085, 1.971_454_922_077_495] + ); + assert_eq!(fitted_theta1, 6.164_579_048_274_036); +}