Iris data (#29)

This commit is contained in:
Patrick Stevens
2023-06-17 19:03:01 +01:00
committed by GitHub
parent 095a8af7f2
commit fd55cd1c5f
11 changed files with 405 additions and 68 deletions

View File

@@ -0,0 +1,50 @@
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
use crate::ext::relu;
use crate::traits::NumLike;
pub struct Block<F, const N: usize> {
f: F,
ranks: [usize; N],
}
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
b1: Block<F, N>,
b2: Block<G, M>,
j: usize,
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
where
F: FnOnce(A, &'a [T]) -> B,
G: FnOnce(B, &'a [T]) -> C,
T: 'a,
{
let mut ranks = [0usize; N + M];
ranks.copy_from_slice(&b1.ranks[..N]);
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
Block {
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
ranks,
}
}
pub fn dense<'a, 'b, A, Tag>(
input_len: usize,
neuron_count: usize,
) -> Block<
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
2,
>
where
Tag: Clone,
A: NumLike + PartialOrd + Default,
{
Block {
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
relu(
t,
&(theta[0].clone().attach_rank().unwrap()),
&(theta[1].clone().attach_rank().unwrap()),
)
},
ranks: [input_len, neuron_count],
}
}

View File

@@ -430,7 +430,7 @@ mod tests {
crate::decider::relu(
&t,
&RankedDifferentiable::of_slice(weights),
Scalar::make(bias.clone()),
Scalar::make(*bias),
)
.clone_real_part()
.into_inner(),

View File

@@ -404,8 +404,11 @@ mod tests {
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]);
assert_eq!(fitted_theta1, 6.1645790482740361);
assert_eq!(
fitted_theta0,
[3.974_645_444_172_085, 1.971_454_922_077_495]
);
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
}
#[test]
@@ -452,7 +455,7 @@ mod tests {
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(
fitted_theta0,
[3.980_262_420_345_729_5, 1.977_071_898_301_443_9]
[3.980_262_420_345_729_5, 1.977_071_898_301_444]
);
assert_eq!(fitted_theta1, 6.170_196_024_282_712_5);
}

View File

@@ -3,6 +3,7 @@
#![feature(array_methods)]
pub mod auto_diff;
pub mod block;
pub mod decider;
pub mod ext;
pub mod gradient_descent;

View File

@@ -106,7 +106,7 @@ mod test_smooth {
]
.map(|x| hydrate(&x));
let mut current = hydrate(&vec![0.8, 3.1, 2.2]);
let mut current = hydrate(&[0.8, 3.1, 2.2]);
let mut output = Vec::with_capacity(inputs.len());
for input in inputs {
current = smooth(decay.clone(), &current, &input);