Iris data (#29)
This commit is contained in:
50
little_learner/src/block.rs
Normal file
50
little_learner/src/block.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
|
||||
use crate::ext::relu;
|
||||
use crate::traits::NumLike;
|
||||
|
||||
pub struct Block<F, const N: usize> {
|
||||
f: F,
|
||||
ranks: [usize; N],
|
||||
}
|
||||
|
||||
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
|
||||
b1: Block<F, N>,
|
||||
b2: Block<G, M>,
|
||||
j: usize,
|
||||
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
|
||||
where
|
||||
F: FnOnce(A, &'a [T]) -> B,
|
||||
G: FnOnce(B, &'a [T]) -> C,
|
||||
T: 'a,
|
||||
{
|
||||
let mut ranks = [0usize; N + M];
|
||||
ranks.copy_from_slice(&b1.ranks[..N]);
|
||||
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
|
||||
Block {
|
||||
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
|
||||
ranks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dense<'a, 'b, A, Tag>(
|
||||
input_len: usize,
|
||||
neuron_count: usize,
|
||||
) -> Block<
|
||||
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
|
||||
2,
|
||||
>
|
||||
where
|
||||
Tag: Clone,
|
||||
A: NumLike + PartialOrd + Default,
|
||||
{
|
||||
Block {
|
||||
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
|
||||
relu(
|
||||
t,
|
||||
&(theta[0].clone().attach_rank().unwrap()),
|
||||
&(theta[1].clone().attach_rank().unwrap()),
|
||||
)
|
||||
},
|
||||
ranks: [input_len, neuron_count],
|
||||
}
|
||||
}
|
@@ -430,7 +430,7 @@ mod tests {
|
||||
crate::decider::relu(
|
||||
&t,
|
||||
&RankedDifferentiable::of_slice(weights),
|
||||
Scalar::make(bias.clone()),
|
||||
Scalar::make(*bias),
|
||||
)
|
||||
.clone_real_part()
|
||||
.into_inner(),
|
||||
|
@@ -404,8 +404,11 @@ mod tests {
|
||||
.map(|x| x.into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]);
|
||||
assert_eq!(fitted_theta1, 6.1645790482740361);
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.974_645_444_172_085, 1.971_454_922_077_495]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.164_579_048_274_036);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -452,7 +455,7 @@ mod tests {
|
||||
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
|
||||
assert_eq!(
|
||||
fitted_theta0,
|
||||
[3.980_262_420_345_729_5, 1.977_071_898_301_443_9]
|
||||
[3.980_262_420_345_729_5, 1.977_071_898_301_444]
|
||||
);
|
||||
assert_eq!(fitted_theta1, 6.170_196_024_282_712_5);
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@
|
||||
#![feature(array_methods)]
|
||||
|
||||
pub mod auto_diff;
|
||||
pub mod block;
|
||||
pub mod decider;
|
||||
pub mod ext;
|
||||
pub mod gradient_descent;
|
||||
|
@@ -106,7 +106,7 @@ mod test_smooth {
|
||||
]
|
||||
.map(|x| hydrate(&x));
|
||||
|
||||
let mut current = hydrate(&vec![0.8, 3.1, 2.2]);
|
||||
let mut current = hydrate(&[0.8, 3.1, 2.2]);
|
||||
let mut output = Vec::with_capacity(inputs.len());
|
||||
for input in inputs {
|
||||
current = smooth(decay.clone(), ¤t, &input);
|
||||
|
Reference in New Issue
Block a user