Iris data (#29)
This commit is contained in:
46
Cargo.lock
generated
46
Cargo.lock
generated
@@ -32,6 +32,27 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "csv"
|
||||
version = "1.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086"
|
||||
dependencies = [
|
||||
"csv-core",
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "csv-core"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "funty"
|
||||
version = "2.0.0"
|
||||
@@ -60,6 +81,12 @@ dependencies = [
|
||||
"packed_struct_codegen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.142"
|
||||
@@ -79,12 +106,19 @@ dependencies = [
|
||||
name = "little_learner_app"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"csv",
|
||||
"immutable-chunkmap",
|
||||
"little_learner",
|
||||
"ordered-float",
|
||||
"rand",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.15"
|
||||
@@ -184,6 +218,18 @@ dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.164"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
|
50
little_learner/src/block.rs
Normal file
50
little_learner/src/block.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
|
||||
use crate::ext::relu;
|
||||
use crate::traits::NumLike;
|
||||
|
||||
pub struct Block<F, const N: usize> {
|
||||
f: F,
|
||||
ranks: [usize; N],
|
||||
}
|
||||
|
||||
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
|
||||
b1: Block<F, N>,
|
||||
b2: Block<G, M>,
|
||||
j: usize,
|
||||
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
|
||||
where
|
||||
F: FnOnce(A, &'a [T]) -> B,
|
||||
G: FnOnce(B, &'a [T]) -> C,
|
||||
T: 'a,
|
||||
{
|
||||
let mut ranks = [0usize; N + M];
|
||||
ranks.copy_from_slice(&b1.ranks[..N]);
|
||||
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
|
||||
Block {
|
||||
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
|
||||
ranks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dense<'a, 'b, A, Tag>(
|
||||
input_len: usize,
|
||||
neuron_count: usize,
|
||||
) -> Block<
|
||||
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
|
||||
2,
|
||||
>
|
||||
where
|
||||
Tag: Clone,
|
||||
A: NumLike + PartialOrd + Default,
|
||||
{
|
||||
Block {
|
||||
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
|
||||
relu(
|
||||
t,
|
||||
&(theta[0].clone().attach_rank().unwrap()),
|
||||
&(theta[1].clone().attach_rank().unwrap()),
|
||||
)
|
||||
},
|
||||
ranks: [input_len, neuron_count],
|
||||
}
|
||||
}
|
@@ -430,7 +430,7 @@ mod tests {
|
||||
crate::decider::relu(
|
||||
&t,
|
||||
&RankedDifferentiable::of_slice(weights),
|
||||
Scalar::make(bias.clone()),
|
||||
Scalar::make(*bias),
|
||||
)
|
||||
.clone_real_part()
|
||||
.into_inner(),
|
||||
|
@@ -404,8 +404,11 @@ mod tests {
|
||||
.map(|x| x.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]);
|
||||
assert_eq!(fitted_theta1, 6.1645790482740361);
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.974_645_444_172_085, 1.971_454_922_077_495]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -452,7 +455,7 @@ mod tests {
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.980_262_420_345_729_5, 1.977_071_898_301_443_9]
|
||||
[3.980_262_420_345_729_5, 1.977_071_898_301_444]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.170_196_024_282_712_5);
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#![feature(array_methods)]
|
||||
|
||||
pub mod auto_diff;
|
||||
pub mod block;
|
||||
pub mod decider;
|
||||
pub mod ext;
|
||||
pub mod gradient_descent;
|
||||
|
@@ -106,7 +106,7 @@ mod test_smooth {
|
||||
]
|
||||
.map(|x| hydrate(&x));
|
||||
|
||||
let mut current = hydrate(&vec![0.8, 3.1, 2.2]);
|
||||
let mut current = hydrate(&[0.8, 3.1, 2.2]);
|
||||
let mut output = Vec::with_capacity(inputs.len());
|
||||
for input in inputs {
|
||||
current = smooth(decay.clone(), ¤t, &input);
|
||||
|
@@ -10,3 +10,4 @@ immutable-chunkmap = "1.0.5"
|
||||
ordered-float = "3.6.0"
|
||||
little_learner = { path = "../little_learner" }
|
||||
rand = "0.8.5"
|
||||
csv = "1.2.2"
|
||||
|
151
little_learner_app/src/iris.csv
Executable file
151
little_learner_app/src/iris.csv
Executable file
@@ -0,0 +1,151 @@
|
||||
5.1,3.5,1.4,0.2,Iris-setosa
|
||||
4.9,3.0,1.4,0.2,Iris-setosa
|
||||
4.7,3.2,1.3,0.2,Iris-setosa
|
||||
4.6,3.1,1.5,0.2,Iris-setosa
|
||||
5.0,3.6,1.4,0.2,Iris-setosa
|
||||
5.4,3.9,1.7,0.4,Iris-setosa
|
||||
4.6,3.4,1.4,0.3,Iris-setosa
|
||||
5.0,3.4,1.5,0.2,Iris-setosa
|
||||
4.4,2.9,1.4,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
5.4,3.7,1.5,0.2,Iris-setosa
|
||||
4.8,3.4,1.6,0.2,Iris-setosa
|
||||
4.8,3.0,1.4,0.1,Iris-setosa
|
||||
4.3,3.0,1.1,0.1,Iris-setosa
|
||||
5.8,4.0,1.2,0.2,Iris-setosa
|
||||
5.7,4.4,1.5,0.4,Iris-setosa
|
||||
5.4,3.9,1.3,0.4,Iris-setosa
|
||||
5.1,3.5,1.4,0.3,Iris-setosa
|
||||
5.7,3.8,1.7,0.3,Iris-setosa
|
||||
5.1,3.8,1.5,0.3,Iris-setosa
|
||||
5.4,3.4,1.7,0.2,Iris-setosa
|
||||
5.1,3.7,1.5,0.4,Iris-setosa
|
||||
4.6,3.6,1.0,0.2,Iris-setosa
|
||||
5.1,3.3,1.7,0.5,Iris-setosa
|
||||
4.8,3.4,1.9,0.2,Iris-setosa
|
||||
5.0,3.0,1.6,0.2,Iris-setosa
|
||||
5.0,3.4,1.6,0.4,Iris-setosa
|
||||
5.2,3.5,1.5,0.2,Iris-setosa
|
||||
5.2,3.4,1.4,0.2,Iris-setosa
|
||||
4.7,3.2,1.6,0.2,Iris-setosa
|
||||
4.8,3.1,1.6,0.2,Iris-setosa
|
||||
5.4,3.4,1.5,0.4,Iris-setosa
|
||||
5.2,4.1,1.5,0.1,Iris-setosa
|
||||
5.5,4.2,1.4,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
5.0,3.2,1.2,0.2,Iris-setosa
|
||||
5.5,3.5,1.3,0.2,Iris-setosa
|
||||
4.9,3.1,1.5,0.1,Iris-setosa
|
||||
4.4,3.0,1.3,0.2,Iris-setosa
|
||||
5.1,3.4,1.5,0.2,Iris-setosa
|
||||
5.0,3.5,1.3,0.3,Iris-setosa
|
||||
4.5,2.3,1.3,0.3,Iris-setosa
|
||||
4.4,3.2,1.3,0.2,Iris-setosa
|
||||
5.0,3.5,1.6,0.6,Iris-setosa
|
||||
5.1,3.8,1.9,0.4,Iris-setosa
|
||||
4.8,3.0,1.4,0.3,Iris-setosa
|
||||
5.1,3.8,1.6,0.2,Iris-setosa
|
||||
4.6,3.2,1.4,0.2,Iris-setosa
|
||||
5.3,3.7,1.5,0.2,Iris-setosa
|
||||
5.0,3.3,1.4,0.2,Iris-setosa
|
||||
7.0,3.2,4.7,1.4,Iris-versicolor
|
||||
6.4,3.2,4.5,1.5,Iris-versicolor
|
||||
6.9,3.1,4.9,1.5,Iris-versicolor
|
||||
5.5,2.3,4.0,1.3,Iris-versicolor
|
||||
6.5,2.8,4.6,1.5,Iris-versicolor
|
||||
5.7,2.8,4.5,1.3,Iris-versicolor
|
||||
6.3,3.3,4.7,1.6,Iris-versicolor
|
||||
4.9,2.4,3.3,1.0,Iris-versicolor
|
||||
6.6,2.9,4.6,1.3,Iris-versicolor
|
||||
5.2,2.7,3.9,1.4,Iris-versicolor
|
||||
5.0,2.0,3.5,1.0,Iris-versicolor
|
||||
5.9,3.0,4.2,1.5,Iris-versicolor
|
||||
6.0,2.2,4.0,1.0,Iris-versicolor
|
||||
6.1,2.9,4.7,1.4,Iris-versicolor
|
||||
5.6,2.9,3.6,1.3,Iris-versicolor
|
||||
6.7,3.1,4.4,1.4,Iris-versicolor
|
||||
5.6,3.0,4.5,1.5,Iris-versicolor
|
||||
5.8,2.7,4.1,1.0,Iris-versicolor
|
||||
6.2,2.2,4.5,1.5,Iris-versicolor
|
||||
5.6,2.5,3.9,1.1,Iris-versicolor
|
||||
5.9,3.2,4.8,1.8,Iris-versicolor
|
||||
6.1,2.8,4.0,1.3,Iris-versicolor
|
||||
6.3,2.5,4.9,1.5,Iris-versicolor
|
||||
6.1,2.8,4.7,1.2,Iris-versicolor
|
||||
6.4,2.9,4.3,1.3,Iris-versicolor
|
||||
6.6,3.0,4.4,1.4,Iris-versicolor
|
||||
6.8,2.8,4.8,1.4,Iris-versicolor
|
||||
6.7,3.0,5.0,1.7,Iris-versicolor
|
||||
6.0,2.9,4.5,1.5,Iris-versicolor
|
||||
5.7,2.6,3.5,1.0,Iris-versicolor
|
||||
5.5,2.4,3.8,1.1,Iris-versicolor
|
||||
5.5,2.4,3.7,1.0,Iris-versicolor
|
||||
5.8,2.7,3.9,1.2,Iris-versicolor
|
||||
6.0,2.7,5.1,1.6,Iris-versicolor
|
||||
5.4,3.0,4.5,1.5,Iris-versicolor
|
||||
6.0,3.4,4.5,1.6,Iris-versicolor
|
||||
6.7,3.1,4.7,1.5,Iris-versicolor
|
||||
6.3,2.3,4.4,1.3,Iris-versicolor
|
||||
5.6,3.0,4.1,1.3,Iris-versicolor
|
||||
5.5,2.5,4.0,1.3,Iris-versicolor
|
||||
5.5,2.6,4.4,1.2,Iris-versicolor
|
||||
6.1,3.0,4.6,1.4,Iris-versicolor
|
||||
5.8,2.6,4.0,1.2,Iris-versicolor
|
||||
5.0,2.3,3.3,1.0,Iris-versicolor
|
||||
5.6,2.7,4.2,1.3,Iris-versicolor
|
||||
5.7,3.0,4.2,1.2,Iris-versicolor
|
||||
5.7,2.9,4.2,1.3,Iris-versicolor
|
||||
6.2,2.9,4.3,1.3,Iris-versicolor
|
||||
5.1,2.5,3.0,1.1,Iris-versicolor
|
||||
5.7,2.8,4.1,1.3,Iris-versicolor
|
||||
6.3,3.3,6.0,2.5,Iris-virginica
|
||||
5.8,2.7,5.1,1.9,Iris-virginica
|
||||
7.1,3.0,5.9,2.1,Iris-virginica
|
||||
6.3,2.9,5.6,1.8,Iris-virginica
|
||||
6.5,3.0,5.8,2.2,Iris-virginica
|
||||
7.6,3.0,6.6,2.1,Iris-virginica
|
||||
4.9,2.5,4.5,1.7,Iris-virginica
|
||||
7.3,2.9,6.3,1.8,Iris-virginica
|
||||
6.7,2.5,5.8,1.8,Iris-virginica
|
||||
7.2,3.6,6.1,2.5,Iris-virginica
|
||||
6.5,3.2,5.1,2.0,Iris-virginica
|
||||
6.4,2.7,5.3,1.9,Iris-virginica
|
||||
6.8,3.0,5.5,2.1,Iris-virginica
|
||||
5.7,2.5,5.0,2.0,Iris-virginica
|
||||
5.8,2.8,5.1,2.4,Iris-virginica
|
||||
6.4,3.2,5.3,2.3,Iris-virginica
|
||||
6.5,3.0,5.5,1.8,Iris-virginica
|
||||
7.7,3.8,6.7,2.2,Iris-virginica
|
||||
7.7,2.6,6.9,2.3,Iris-virginica
|
||||
6.0,2.2,5.0,1.5,Iris-virginica
|
||||
6.9,3.2,5.7,2.3,Iris-virginica
|
||||
5.6,2.8,4.9,2.0,Iris-virginica
|
||||
7.7,2.8,6.7,2.0,Iris-virginica
|
||||
6.3,2.7,4.9,1.8,Iris-virginica
|
||||
6.7,3.3,5.7,2.1,Iris-virginica
|
||||
7.2,3.2,6.0,1.8,Iris-virginica
|
||||
6.2,2.8,4.8,1.8,Iris-virginica
|
||||
6.1,3.0,4.9,1.8,Iris-virginica
|
||||
6.4,2.8,5.6,2.1,Iris-virginica
|
||||
7.2,3.0,5.8,1.6,Iris-virginica
|
||||
7.4,2.8,6.1,1.9,Iris-virginica
|
||||
7.9,3.8,6.4,2.0,Iris-virginica
|
||||
6.4,2.8,5.6,2.2,Iris-virginica
|
||||
6.3,2.8,5.1,1.5,Iris-virginica
|
||||
6.1,2.6,5.6,1.4,Iris-virginica
|
||||
7.7,3.0,6.1,2.3,Iris-virginica
|
||||
6.3,3.4,5.6,2.4,Iris-virginica
|
||||
6.4,3.1,5.5,1.8,Iris-virginica
|
||||
6.0,3.0,4.8,1.8,Iris-virginica
|
||||
6.9,3.1,5.4,2.1,Iris-virginica
|
||||
6.7,3.1,5.6,2.4,Iris-virginica
|
||||
6.9,3.1,5.1,2.3,Iris-virginica
|
||||
5.8,2.7,5.1,1.9,Iris-virginica
|
||||
6.8,3.2,5.9,2.3,Iris-virginica
|
||||
6.7,3.3,5.7,2.5,Iris-virginica
|
||||
6.7,3.0,5.2,2.3,Iris-virginica
|
||||
6.3,2.5,5.0,1.9,Iris-virginica
|
||||
6.5,3.0,5.2,2.0,Iris-virginica
|
||||
6.2,3.4,5.4,2.3,Iris-virginica
|
||||
5.9,3.0,5.1,1.8,Iris-virginica
|
||||
|
|
77
little_learner_app/src/iris.rs
Normal file
77
little_learner_app/src/iris.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use csv::ReaderBuilder;
|
||||
use std::io::Cursor;
|
||||
use std::str::FromStr;
|
||||
|
||||
const IRIS_DATA: &str = include_str!("iris.csv");
|
||||
|
||||
#[derive(Eq, PartialEq, Debug)]
|
||||
pub enum IrisType {
|
||||
Setosa,
|
||||
Versicolor,
|
||||
Virginica,
|
||||
}
|
||||
|
||||
impl FromStr for IrisType {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"Iris-virginica" => Ok(IrisType::Virginica),
|
||||
"Iris-versicolor" => Ok(IrisType::Versicolor),
|
||||
"Iris-setosa" => Ok(IrisType::Setosa),
|
||||
_ => Err(String::from(s)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub struct Iris {
|
||||
pub class: IrisType,
|
||||
pub petal_length: f32,
|
||||
pub petal_width: f32,
|
||||
pub sepal_length: f32,
|
||||
pub sepal_width: f32,
|
||||
}
|
||||
|
||||
pub fn import() -> Vec<Iris> {
|
||||
let mut reader = ReaderBuilder::new()
|
||||
.has_headers(false)
|
||||
.from_reader(Cursor::new(IRIS_DATA));
|
||||
let mut output = Vec::new();
|
||||
for record in reader.records() {
|
||||
let record = record.unwrap();
|
||||
let petal_length = f32::from_str(&record[0]).unwrap();
|
||||
let petal_width = f32::from_str(&record[1]).unwrap();
|
||||
let sepal_length = f32::from_str(&record[2]).unwrap();
|
||||
let sepal_width = f32::from_str(&record[3]).unwrap();
|
||||
let class = IrisType::from_str(&record[4]).unwrap();
|
||||
output.push(Iris {
|
||||
class,
|
||||
petal_length,
|
||||
petal_width,
|
||||
sepal_length,
|
||||
sepal_width,
|
||||
});
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
pub(crate) const EXPECTED_FIRST: Iris = Iris {
|
||||
class: IrisType::Setosa,
|
||||
petal_length: 5.1,
|
||||
petal_width: 3.5,
|
||||
sepal_length: 1.4,
|
||||
sepal_width: 0.2,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::iris::import;
|
||||
|
||||
#[test]
|
||||
fn first_element() {
|
||||
let irises = import();
|
||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
||||
}
|
||||
}
|
@@ -1,71 +1,14 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
|
||||
use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||
use crate::rms_example::rms_example;
|
||||
|
||||
use little_learner::gradient_descent::gradient_descent;
|
||||
use little_learner::hyper;
|
||||
use little_learner::loss::predict_plane;
|
||||
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use little_learner::predictor;
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::Zero;
|
||||
use ordered_float::NotNan;
|
||||
|
||||
const PLANE_XS: [[f64; 2]; 6] = [
|
||||
[1.0, 2.05],
|
||||
[1.0, 3.0],
|
||||
[2.0, 2.0],
|
||||
[2.0, 3.91],
|
||||
[3.0, 6.13],
|
||||
[4.0, 8.09],
|
||||
];
|
||||
const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
|
||||
mod iris;
|
||||
mod rms_example;
|
||||
|
||||
fn main() {
|
||||
let beta = NotNan::new(0.9).expect("not nan");
|
||||
let stabilizer = NotNan::new(0.000_000_01).expect("not nan");
|
||||
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
|
||||
.with_stabilizer(stabilizer)
|
||||
.with_beta(beta);
|
||||
rms_example();
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(PLANE_XS);
|
||||
let ys = to_not_nan_1(PLANE_YS);
|
||||
let zero_params = [
|
||||
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
|
||||
.to_unranked(),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&xs,
|
||||
RankedDifferentiableTagged::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
zero_params,
|
||||
predictor::rms(predict_plane),
|
||||
hyper::RmsGradientDescent::to_immutable,
|
||||
)
|
||||
};
|
||||
|
||||
let [theta0, theta1] = iterated;
|
||||
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
let fitted_theta0 = theta0
|
||||
.collect()
|
||||
.iter()
|
||||
.map(|x| x.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.974_645_444_172_085, 1.971_454_922_077_495]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
|
||||
let irises = iris::import();
|
||||
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
|
65
little_learner_app/src/rms_example.rs
Normal file
65
little_learner_app/src/rms_example.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||
|
||||
use little_learner::gradient_descent::gradient_descent;
|
||||
use little_learner::hyper;
|
||||
use little_learner::loss::predict_plane;
|
||||
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use little_learner::predictor;
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::Zero;
|
||||
use ordered_float::NotNan;
|
||||
|
||||
const PLANE_XS: [[f64; 2]; 6] = [
|
||||
[1.0, 2.05],
|
||||
[1.0, 3.0],
|
||||
[2.0, 2.0],
|
||||
[2.0, 3.91],
|
||||
[3.0, 6.13],
|
||||
[4.0, 8.09],
|
||||
];
|
||||
const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
|
||||
|
||||
pub(crate) fn rms_example() {
|
||||
let beta = NotNan::new(0.9).expect("not nan");
|
||||
let stabilizer = NotNan::new(0.000_000_01).expect("not nan");
|
||||
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
|
||||
.with_stabilizer(stabilizer)
|
||||
.with_beta(beta);
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(PLANE_XS);
|
||||
let ys = to_not_nan_1(PLANE_YS);
|
||||
let zero_params = [
|
||||
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
|
||||
.to_unranked(),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&xs,
|
||||
RankedDifferentiableTagged::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
zero_params,
|
||||
predictor::rms(predict_plane),
|
||||
hyper::RmsGradientDescent::to_immutable,
|
||||
)
|
||||
};
|
||||
|
||||
let [theta0, theta1] = iterated;
|
||||
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
let fitted_theta0 = theta0
|
||||
.collect()
|
||||
.iter()
|
||||
.map(|x| x.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.974_645_444_172_085, 1.971_454_922_077_495]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
|
||||
}
|
Reference in New Issue
Block a user