Iris data (#29)

This commit is contained in:
Patrick Stevens
2023-06-17 19:03:01 +01:00
committed by GitHub
parent 095a8af7f2
commit fd55cd1c5f
11 changed files with 405 additions and 68 deletions

46
Cargo.lock generated
View File

@@ -32,6 +32,27 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "csv"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "626ae34994d3d8d668f4269922248239db4ae42d538b14c398b74a52208e8086"
dependencies = [
"csv-core",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "csv-core"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b2466559f260f48ad25fe6317b3c8dac77b5bdb5763ac7d9d6103530663bc90"
dependencies = [
"memchr",
]
[[package]]
name = "funty"
version = "2.0.0"
@@ -60,6 +81,12 @@ dependencies = [
"packed_struct_codegen",
]
[[package]]
name = "itoa"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6"
[[package]]
name = "libc"
version = "0.2.142"
@@ -79,12 +106,19 @@ dependencies = [
name = "little_learner_app"
version = "0.1.0"
dependencies = [
"csv",
"immutable-chunkmap",
"little_learner",
"ordered-float",
"rand",
]
[[package]]
name = "memchr"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
[[package]]
name = "num-traits"
version = "0.2.15"
@@ -184,6 +218,18 @@ dependencies = [
"getrandom",
]
[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041"
[[package]]
name = "serde"
version = "1.0.164"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d"
[[package]]
name = "syn"
version = "1.0.109"

View File

@@ -0,0 +1,50 @@
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
use crate::ext::relu;
use crate::traits::NumLike;
pub struct Block<F, const N: usize> {
f: F,
ranks: [usize; N],
}
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
b1: Block<F, N>,
b2: Block<G, M>,
j: usize,
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
where
F: FnOnce(A, &'a [T]) -> B,
G: FnOnce(B, &'a [T]) -> C,
T: 'a,
{
let mut ranks = [0usize; N + M];
ranks.copy_from_slice(&b1.ranks[..N]);
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
Block {
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
ranks,
}
}
pub fn dense<'a, 'b, A, Tag>(
input_len: usize,
neuron_count: usize,
) -> Block<
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
2,
>
where
Tag: Clone,
A: NumLike + PartialOrd + Default,
{
Block {
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
relu(
t,
&(theta[0].clone().attach_rank().unwrap()),
&(theta[1].clone().attach_rank().unwrap()),
)
},
ranks: [input_len, neuron_count],
}
}

View File

@@ -430,7 +430,7 @@ mod tests {
crate::decider::relu(
&t,
&RankedDifferentiable::of_slice(weights),
Scalar::make(bias.clone()),
Scalar::make(*bias),
)
.clone_real_part()
.into_inner(),

View File

@@ -404,8 +404,11 @@ mod tests {
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]);
assert_eq!(fitted_theta1, 6.1645790482740361);
assert_eq!(
fitted_theta0,
[3.974_645_444_172_085, 1.971_454_922_077_495]
);
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
}
#[test]
@@ -452,7 +455,7 @@ mod tests {
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(
fitted_theta0,
[3.980_262_420_345_729_5, 1.977_071_898_301_443_9]
[3.980_262_420_345_729_5, 1.977_071_898_301_444]
);
assert_eq!(fitted_theta1, 6.170_196_024_282_712_5);
}

View File

@@ -3,6 +3,7 @@
#![feature(array_methods)]
pub mod auto_diff;
pub mod block;
pub mod decider;
pub mod ext;
pub mod gradient_descent;

View File

@@ -106,7 +106,7 @@ mod test_smooth {
]
.map(|x| hydrate(&x));
let mut current = hydrate(&vec![0.8, 3.1, 2.2]);
let mut current = hydrate(&[0.8, 3.1, 2.2]);
let mut output = Vec::with_capacity(inputs.len());
for input in inputs {
current = smooth(decay.clone(), &current, &input);

View File

@@ -10,3 +10,4 @@ immutable-chunkmap = "1.0.5"
ordered-float = "3.6.0"
little_learner = { path = "../little_learner" }
rand = "0.8.5"
csv = "1.2.2"

151
little_learner_app/src/iris.csv Executable file
View File

@@ -0,0 +1,151 @@
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
1 5.1 3.5 1.4 0.2 Iris-setosa
2 4.9 3.0 1.4 0.2 Iris-setosa
3 4.7 3.2 1.3 0.2 Iris-setosa
4 4.6 3.1 1.5 0.2 Iris-setosa
5 5.0 3.6 1.4 0.2 Iris-setosa
6 5.4 3.9 1.7 0.4 Iris-setosa
7 4.6 3.4 1.4 0.3 Iris-setosa
8 5.0 3.4 1.5 0.2 Iris-setosa
9 4.4 2.9 1.4 0.2 Iris-setosa
10 4.9 3.1 1.5 0.1 Iris-setosa
11 5.4 3.7 1.5 0.2 Iris-setosa
12 4.8 3.4 1.6 0.2 Iris-setosa
13 4.8 3.0 1.4 0.1 Iris-setosa
14 4.3 3.0 1.1 0.1 Iris-setosa
15 5.8 4.0 1.2 0.2 Iris-setosa
16 5.7 4.4 1.5 0.4 Iris-setosa
17 5.4 3.9 1.3 0.4 Iris-setosa
18 5.1 3.5 1.4 0.3 Iris-setosa
19 5.7 3.8 1.7 0.3 Iris-setosa
20 5.1 3.8 1.5 0.3 Iris-setosa
21 5.4 3.4 1.7 0.2 Iris-setosa
22 5.1 3.7 1.5 0.4 Iris-setosa
23 4.6 3.6 1.0 0.2 Iris-setosa
24 5.1 3.3 1.7 0.5 Iris-setosa
25 4.8 3.4 1.9 0.2 Iris-setosa
26 5.0 3.0 1.6 0.2 Iris-setosa
27 5.0 3.4 1.6 0.4 Iris-setosa
28 5.2 3.5 1.5 0.2 Iris-setosa
29 5.2 3.4 1.4 0.2 Iris-setosa
30 4.7 3.2 1.6 0.2 Iris-setosa
31 4.8 3.1 1.6 0.2 Iris-setosa
32 5.4 3.4 1.5 0.4 Iris-setosa
33 5.2 4.1 1.5 0.1 Iris-setosa
34 5.5 4.2 1.4 0.2 Iris-setosa
35 4.9 3.1 1.5 0.1 Iris-setosa
36 5.0 3.2 1.2 0.2 Iris-setosa
37 5.5 3.5 1.3 0.2 Iris-setosa
38 4.9 3.1 1.5 0.1 Iris-setosa
39 4.4 3.0 1.3 0.2 Iris-setosa
40 5.1 3.4 1.5 0.2 Iris-setosa
41 5.0 3.5 1.3 0.3 Iris-setosa
42 4.5 2.3 1.3 0.3 Iris-setosa
43 4.4 3.2 1.3 0.2 Iris-setosa
44 5.0 3.5 1.6 0.6 Iris-setosa
45 5.1 3.8 1.9 0.4 Iris-setosa
46 4.8 3.0 1.4 0.3 Iris-setosa
47 5.1 3.8 1.6 0.2 Iris-setosa
48 4.6 3.2 1.4 0.2 Iris-setosa
49 5.3 3.7 1.5 0.2 Iris-setosa
50 5.0 3.3 1.4 0.2 Iris-setosa
51 7.0 3.2 4.7 1.4 Iris-versicolor
52 6.4 3.2 4.5 1.5 Iris-versicolor
53 6.9 3.1 4.9 1.5 Iris-versicolor
54 5.5 2.3 4.0 1.3 Iris-versicolor
55 6.5 2.8 4.6 1.5 Iris-versicolor
56 5.7 2.8 4.5 1.3 Iris-versicolor
57 6.3 3.3 4.7 1.6 Iris-versicolor
58 4.9 2.4 3.3 1.0 Iris-versicolor
59 6.6 2.9 4.6 1.3 Iris-versicolor
60 5.2 2.7 3.9 1.4 Iris-versicolor
61 5.0 2.0 3.5 1.0 Iris-versicolor
62 5.9 3.0 4.2 1.5 Iris-versicolor
63 6.0 2.2 4.0 1.0 Iris-versicolor
64 6.1 2.9 4.7 1.4 Iris-versicolor
65 5.6 2.9 3.6 1.3 Iris-versicolor
66 6.7 3.1 4.4 1.4 Iris-versicolor
67 5.6 3.0 4.5 1.5 Iris-versicolor
68 5.8 2.7 4.1 1.0 Iris-versicolor
69 6.2 2.2 4.5 1.5 Iris-versicolor
70 5.6 2.5 3.9 1.1 Iris-versicolor
71 5.9 3.2 4.8 1.8 Iris-versicolor
72 6.1 2.8 4.0 1.3 Iris-versicolor
73 6.3 2.5 4.9 1.5 Iris-versicolor
74 6.1 2.8 4.7 1.2 Iris-versicolor
75 6.4 2.9 4.3 1.3 Iris-versicolor
76 6.6 3.0 4.4 1.4 Iris-versicolor
77 6.8 2.8 4.8 1.4 Iris-versicolor
78 6.7 3.0 5.0 1.7 Iris-versicolor
79 6.0 2.9 4.5 1.5 Iris-versicolor
80 5.7 2.6 3.5 1.0 Iris-versicolor
81 5.5 2.4 3.8 1.1 Iris-versicolor
82 5.5 2.4 3.7 1.0 Iris-versicolor
83 5.8 2.7 3.9 1.2 Iris-versicolor
84 6.0 2.7 5.1 1.6 Iris-versicolor
85 5.4 3.0 4.5 1.5 Iris-versicolor
86 6.0 3.4 4.5 1.6 Iris-versicolor
87 6.7 3.1 4.7 1.5 Iris-versicolor
88 6.3 2.3 4.4 1.3 Iris-versicolor
89 5.6 3.0 4.1 1.3 Iris-versicolor
90 5.5 2.5 4.0 1.3 Iris-versicolor
91 5.5 2.6 4.4 1.2 Iris-versicolor
92 6.1 3.0 4.6 1.4 Iris-versicolor
93 5.8 2.6 4.0 1.2 Iris-versicolor
94 5.0 2.3 3.3 1.0 Iris-versicolor
95 5.6 2.7 4.2 1.3 Iris-versicolor
96 5.7 3.0 4.2 1.2 Iris-versicolor
97 5.7 2.9 4.2 1.3 Iris-versicolor
98 6.2 2.9 4.3 1.3 Iris-versicolor
99 5.1 2.5 3.0 1.1 Iris-versicolor
100 5.7 2.8 4.1 1.3 Iris-versicolor
101 6.3 3.3 6.0 2.5 Iris-virginica
102 5.8 2.7 5.1 1.9 Iris-virginica
103 7.1 3.0 5.9 2.1 Iris-virginica
104 6.3 2.9 5.6 1.8 Iris-virginica
105 6.5 3.0 5.8 2.2 Iris-virginica
106 7.6 3.0 6.6 2.1 Iris-virginica
107 4.9 2.5 4.5 1.7 Iris-virginica
108 7.3 2.9 6.3 1.8 Iris-virginica
109 6.7 2.5 5.8 1.8 Iris-virginica
110 7.2 3.6 6.1 2.5 Iris-virginica
111 6.5 3.2 5.1 2.0 Iris-virginica
112 6.4 2.7 5.3 1.9 Iris-virginica
113 6.8 3.0 5.5 2.1 Iris-virginica
114 5.7 2.5 5.0 2.0 Iris-virginica
115 5.8 2.8 5.1 2.4 Iris-virginica
116 6.4 3.2 5.3 2.3 Iris-virginica
117 6.5 3.0 5.5 1.8 Iris-virginica
118 7.7 3.8 6.7 2.2 Iris-virginica
119 7.7 2.6 6.9 2.3 Iris-virginica
120 6.0 2.2 5.0 1.5 Iris-virginica
121 6.9 3.2 5.7 2.3 Iris-virginica
122 5.6 2.8 4.9 2.0 Iris-virginica
123 7.7 2.8 6.7 2.0 Iris-virginica
124 6.3 2.7 4.9 1.8 Iris-virginica
125 6.7 3.3 5.7 2.1 Iris-virginica
126 7.2 3.2 6.0 1.8 Iris-virginica
127 6.2 2.8 4.8 1.8 Iris-virginica
128 6.1 3.0 4.9 1.8 Iris-virginica
129 6.4 2.8 5.6 2.1 Iris-virginica
130 7.2 3.0 5.8 1.6 Iris-virginica
131 7.4 2.8 6.1 1.9 Iris-virginica
132 7.9 3.8 6.4 2.0 Iris-virginica
133 6.4 2.8 5.6 2.2 Iris-virginica
134 6.3 2.8 5.1 1.5 Iris-virginica
135 6.1 2.6 5.6 1.4 Iris-virginica
136 7.7 3.0 6.1 2.3 Iris-virginica
137 6.3 3.4 5.6 2.4 Iris-virginica
138 6.4 3.1 5.5 1.8 Iris-virginica
139 6.0 3.0 4.8 1.8 Iris-virginica
140 6.9 3.1 5.4 2.1 Iris-virginica
141 6.7 3.1 5.6 2.4 Iris-virginica
142 6.9 3.1 5.1 2.3 Iris-virginica
143 5.8 2.7 5.1 1.9 Iris-virginica
144 6.8 3.2 5.9 2.3 Iris-virginica
145 6.7 3.3 5.7 2.5 Iris-virginica
146 6.7 3.0 5.2 2.3 Iris-virginica
147 6.3 2.5 5.0 1.9 Iris-virginica
148 6.5 3.0 5.2 2.0 Iris-virginica
149 6.2 3.4 5.4 2.3 Iris-virginica
150 5.9 3.0 5.1 1.8 Iris-virginica

View File

@@ -0,0 +1,77 @@
use csv::ReaderBuilder;
use std::io::Cursor;
use std::str::FromStr;
const IRIS_DATA: &str = include_str!("iris.csv");
#[derive(Eq, PartialEq, Debug)]
pub enum IrisType {
Setosa,
Versicolor,
Virginica,
}
impl FromStr for IrisType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"Iris-virginica" => Ok(IrisType::Virginica),
"Iris-versicolor" => Ok(IrisType::Versicolor),
"Iris-setosa" => Ok(IrisType::Setosa),
_ => Err(String::from(s)),
}
}
}
#[derive(PartialEq, Debug)]
pub struct Iris {
pub class: IrisType,
pub petal_length: f32,
pub petal_width: f32,
pub sepal_length: f32,
pub sepal_width: f32,
}
pub fn import() -> Vec<Iris> {
let mut reader = ReaderBuilder::new()
.has_headers(false)
.from_reader(Cursor::new(IRIS_DATA));
let mut output = Vec::new();
for record in reader.records() {
let record = record.unwrap();
let petal_length = f32::from_str(&record[0]).unwrap();
let petal_width = f32::from_str(&record[1]).unwrap();
let sepal_length = f32::from_str(&record[2]).unwrap();
let sepal_width = f32::from_str(&record[3]).unwrap();
let class = IrisType::from_str(&record[4]).unwrap();
output.push(Iris {
class,
petal_length,
petal_width,
sepal_length,
sepal_width,
});
}
output
}
pub(crate) const EXPECTED_FIRST: Iris = Iris {
class: IrisType::Setosa,
petal_length: 5.1,
petal_width: 3.5,
sepal_length: 1.4,
sepal_width: 0.2,
};
#[cfg(test)]
mod test {
use crate::iris::import;
#[test]
fn first_element() {
let irises = import();
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
}
}

View File

@@ -1,71 +1,14 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
use crate::rms_example::rms_example;
use little_learner::gradient_descent::gradient_descent;
use little_learner::hyper;
use little_learner::loss::predict_plane;
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
use little_learner::predictor;
use little_learner::scalar::Scalar;
use little_learner::traits::Zero;
use ordered_float::NotNan;
const PLANE_XS: [[f64; 2]; 6] = [
[1.0, 2.05],
[1.0, 3.0],
[2.0, 2.0],
[2.0, 3.91],
[3.0, 6.13],
[4.0, 8.09],
];
const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
mod iris;
mod rms_example;
fn main() {
let beta = NotNan::new(0.9).expect("not nan");
let stabilizer = NotNan::new(0.000_000_01).expect("not nan");
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
.with_stabilizer(stabilizer)
.with_beta(beta);
rms_example();
let iterated = {
let xs = to_not_nan_2(PLANE_XS);
let ys = to_not_nan_1(PLANE_YS);
let zero_params = [
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
.to_unranked(),
Differentiable::of_scalar(Scalar::zero()),
];
gradient_descent(
hyper,
&xs,
RankedDifferentiableTagged::of_slice_2::<_, 2>,
&ys,
zero_params,
predictor::rms(predict_plane),
hyper::RmsGradientDescent::to_immutable,
)
};
let [theta0, theta1] = iterated;
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
let fitted_theta0 = theta0
.collect()
.iter()
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(
fitted_theta0,
[3.974_645_444_172_085, 1.971_454_922_077_495]
);
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
let irises = iris::import();
assert_eq!(irises[0], crate::iris::EXPECTED_FIRST);
}
#[cfg(test)]
mod tests {}

View File

@@ -0,0 +1,65 @@
use little_learner::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
use little_learner::gradient_descent::gradient_descent;
use little_learner::hyper;
use little_learner::loss::predict_plane;
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
use little_learner::predictor;
use little_learner::scalar::Scalar;
use little_learner::traits::Zero;
use ordered_float::NotNan;
const PLANE_XS: [[f64; 2]; 6] = [
[1.0, 2.05],
[1.0, 3.0],
[2.0, 2.0],
[2.0, 3.91],
[3.0, 6.13],
[4.0, 8.09],
];
const PLANE_YS: [f64; 6] = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
pub(crate) fn rms_example() {
let beta = NotNan::new(0.9).expect("not nan");
let stabilizer = NotNan::new(0.000_000_01).expect("not nan");
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
.with_stabilizer(stabilizer)
.with_beta(beta);
let iterated = {
let xs = to_not_nan_2(PLANE_XS);
let ys = to_not_nan_1(PLANE_YS);
let zero_params = [
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
.to_unranked(),
Differentiable::of_scalar(Scalar::zero()),
];
gradient_descent(
hyper,
&xs,
RankedDifferentiableTagged::of_slice_2::<_, 2>,
&ys,
zero_params,
predictor::rms(predict_plane),
hyper::RmsGradientDescent::to_immutable,
)
};
let [theta0, theta1] = iterated;
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
let fitted_theta0 = theta0
.collect()
.iter()
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(
fitted_theta0,
[3.974_645_444_172_085, 1.971_454_922_077_495]
);
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
}