Baby's first layer (#25)
This commit is contained in:
@@ -69,7 +69,7 @@ where
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum DifferentiableContents<A, Tag> {
|
enum DifferentiableContents<A, Tag> {
|
||||||
Scalar(Scalar<A>, Tag),
|
Scalar(Scalar<A>, Tag),
|
||||||
// Contains the rank.
|
// Contains the rank of this differentiable (i.e. one more than the rank of the inputs).
|
||||||
Vector(Vec<DifferentiableTagged<A, Tag>>, usize),
|
Vector(Vec<DifferentiableTagged<A, Tag>>, usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,6 +199,64 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Unwraps one layer of each input, so the passed function takes inputs which have decreased
|
||||||
|
/// the ranks of the `map2_once_tagged` input by one.
|
||||||
|
/// Panics if passed a scalar or if the input vectors are not the same length.
|
||||||
|
pub fn map2_once_tagged<B, C, Tag2, Tag3, F>(
|
||||||
|
self: &DifferentiableContents<A, Tag>,
|
||||||
|
other: &DifferentiableContents<B, Tag2>,
|
||||||
|
mut f: F,
|
||||||
|
) -> DifferentiableContents<C, Tag3>
|
||||||
|
where
|
||||||
|
F: FnMut(
|
||||||
|
&DifferentiableTagged<A, Tag>,
|
||||||
|
&DifferentiableTagged<B, Tag2>,
|
||||||
|
) -> DifferentiableTagged<C, Tag3>,
|
||||||
|
{
|
||||||
|
match (self, other) {
|
||||||
|
(DifferentiableContents::Scalar(_, _), _) => {
|
||||||
|
panic!("First arg needed to have non-scalar rank")
|
||||||
|
}
|
||||||
|
(_, DifferentiableContents::Scalar(_, _)) => {
|
||||||
|
panic!("Second arg needed to have non-scalar rank")
|
||||||
|
}
|
||||||
|
(
|
||||||
|
DifferentiableContents::Vector(v1, rank1),
|
||||||
|
DifferentiableContents::Vector(v2, _rank2),
|
||||||
|
) => {
|
||||||
|
assert_eq!(
|
||||||
|
v1.len(),
|
||||||
|
v2.len(),
|
||||||
|
"Must map two vectors of the same length, got {rank1} and {_rank2}"
|
||||||
|
);
|
||||||
|
assert_ne!(
|
||||||
|
v1.len(),
|
||||||
|
0,
|
||||||
|
"Cannot determine a rank of a zero-length vector"
|
||||||
|
);
|
||||||
|
let mut rank = 0usize;
|
||||||
|
DifferentiableContents::Vector(
|
||||||
|
v1.iter()
|
||||||
|
.zip(v2.iter())
|
||||||
|
.map(|(a, b)| {
|
||||||
|
let result = f(a, b);
|
||||||
|
match result.contents {
|
||||||
|
DifferentiableContents::Vector(_, discovered_rank) => {
|
||||||
|
rank = discovered_rank + 1;
|
||||||
|
}
|
||||||
|
DifferentiableContents::Scalar(_, _) => {
|
||||||
|
rank = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
rank,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn of_slice<'a, T, I>(tag: Tag, input: I) -> DifferentiableContents<T, Tag>
|
fn of_slice<'a, T, I>(tag: Tag, input: I) -> DifferentiableContents<T, Tag>
|
||||||
where
|
where
|
||||||
T: Clone + 'a,
|
T: Clone + 'a,
|
||||||
@@ -277,6 +335,22 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn map2_once_tagged<B, C, Tag2, Tag3, F>(
|
||||||
|
self: &DifferentiableTagged<A, Tag>,
|
||||||
|
other: &DifferentiableTagged<B, Tag2>,
|
||||||
|
f: F,
|
||||||
|
) -> DifferentiableTagged<C, Tag3>
|
||||||
|
where
|
||||||
|
F: FnMut(
|
||||||
|
&DifferentiableTagged<A, Tag>,
|
||||||
|
&DifferentiableTagged<B, Tag2>,
|
||||||
|
) -> DifferentiableTagged<C, Tag3>,
|
||||||
|
{
|
||||||
|
DifferentiableTagged {
|
||||||
|
contents: self.contents.map2_once_tagged(&other.contents, f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn attach_rank<const RANK: usize>(
|
pub fn attach_rank<const RANK: usize>(
|
||||||
self: DifferentiableTagged<A, Tag>,
|
self: DifferentiableTagged<A, Tag>,
|
||||||
) -> Option<RankedDifferentiableTagged<A, Tag, RANK>> {
|
) -> Option<RankedDifferentiableTagged<A, Tag, RANK>> {
|
||||||
@@ -582,10 +656,10 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn map2_tagged<B, C, Tag2, Tag3, F>(
|
pub fn map2_tagged<'a, 'b, B, C, Tag2, Tag3, F>(
|
||||||
self: &RankedDifferentiableTagged<A, Tag, RANK>,
|
self: &'a RankedDifferentiableTagged<A, Tag, RANK>,
|
||||||
other: &RankedDifferentiableTagged<B, Tag2, RANK>,
|
other: &'a RankedDifferentiableTagged<B, Tag2, RANK>,
|
||||||
f: &mut F,
|
f: &'b mut F,
|
||||||
) -> RankedDifferentiableTagged<C, Tag3, RANK>
|
) -> RankedDifferentiableTagged<C, Tag3, RANK>
|
||||||
where
|
where
|
||||||
F: FnMut(&Scalar<A>, Tag, &Scalar<B>, Tag2) -> (Scalar<C>, Tag3),
|
F: FnMut(&Scalar<A>, Tag, &Scalar<B>, Tag2) -> (Scalar<C>, Tag3),
|
||||||
@@ -598,6 +672,44 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
|
|||||||
contents: DifferentiableTagged::map2_tagged(&self.contents, &other.contents, f),
|
contents: DifferentiableTagged::map2_tagged(&self.contents, &other.contents, f),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn map2_once_tagged<
|
||||||
|
'a,
|
||||||
|
'c,
|
||||||
|
B,
|
||||||
|
C: 'a,
|
||||||
|
Tag2,
|
||||||
|
Tag3: 'a,
|
||||||
|
F,
|
||||||
|
const RANK_B: usize,
|
||||||
|
const RANK_OUT: usize,
|
||||||
|
>(
|
||||||
|
self: &'a RankedDifferentiableTagged<A, Tag, RANK>,
|
||||||
|
other: &'a RankedDifferentiableTagged<B, Tag2, RANK_B>,
|
||||||
|
f: &'c mut F,
|
||||||
|
) -> RankedDifferentiableTagged<C, Tag3, RANK_OUT>
|
||||||
|
where
|
||||||
|
F: FnMut(
|
||||||
|
&RankedDifferentiableTagged<A, Tag, { RANK - 1 }>,
|
||||||
|
&RankedDifferentiableTagged<B, Tag2, { RANK_B - 1 }>,
|
||||||
|
) -> RankedDifferentiableTagged<C, Tag3, { RANK_OUT - 1 }>,
|
||||||
|
A: Clone,
|
||||||
|
B: Clone,
|
||||||
|
Tag: Clone,
|
||||||
|
Tag2: Clone,
|
||||||
|
'c: 'a,
|
||||||
|
{
|
||||||
|
RankedDifferentiableTagged {
|
||||||
|
contents: DifferentiableTagged::map2_once_tagged(
|
||||||
|
&self.contents,
|
||||||
|
&other.contents,
|
||||||
|
&mut |a: &DifferentiableTagged<A, Tag>, b: &DifferentiableTagged<B, Tag2>| {
|
||||||
|
let a = (*a).clone().attach_rank::<{ RANK - 1 }>().unwrap();
|
||||||
|
let b = (*b).clone().attach_rank::<{ RANK_B - 1 }>().unwrap();
|
||||||
|
f(&a, &b).to_unranked()
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_vector(
|
pub fn to_vector(
|
||||||
self: RankedDifferentiableTagged<A, Tag, RANK>,
|
self: RankedDifferentiableTagged<A, Tag, RANK>,
|
||||||
@@ -634,6 +746,22 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
|||||||
{
|
{
|
||||||
self.map2_tagged(other, &mut |a, (), b, ()| (f(a, b), ()))
|
self.map2_tagged(other, &mut |a, (), b, ()| (f(a, b), ()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn map2_once<B, C, F, const RANK_B: usize, const RANK_OUT: usize>(
|
||||||
|
self: &RankedDifferentiable<A, RANK>,
|
||||||
|
other: &RankedDifferentiable<B, RANK_B>,
|
||||||
|
f: &mut F,
|
||||||
|
) -> RankedDifferentiable<C, RANK_OUT>
|
||||||
|
where
|
||||||
|
F: FnMut(
|
||||||
|
&RankedDifferentiable<A, { RANK - 1 }>,
|
||||||
|
&RankedDifferentiable<B, { RANK_B - 1 }>,
|
||||||
|
) -> RankedDifferentiable<C, { RANK_OUT - 1 }>,
|
||||||
|
A: Clone,
|
||||||
|
B: Clone,
|
||||||
|
{
|
||||||
|
self.map2_once_tagged(other, f)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn grad<A, Tag, F, const RANK: usize, const PARAM_RANK: usize>(
|
pub fn grad<A, Tag, F, const RANK: usize, const PARAM_RANK: usize>(
|
||||||
|
@@ -15,19 +15,19 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn linear<A, Tag1, Tag2>(
|
fn linear<A, Tag1, Tag2>(
|
||||||
t: RankedDifferentiableTagged<A, Tag1, 1>,
|
t: &RankedDifferentiableTagged<A, Tag1, 1>,
|
||||||
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
|
theta0: &RankedDifferentiableTagged<A, Tag2, 1>,
|
||||||
theta1: Scalar<A>,
|
theta1: Scalar<A>,
|
||||||
) -> Scalar<A>
|
) -> Scalar<A>
|
||||||
where
|
where
|
||||||
A: NumLike,
|
A: NumLike,
|
||||||
{
|
{
|
||||||
dot(&theta0, &t) + theta1
|
dot(theta0, t) + theta1
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn relu<A, Tag1, Tag2>(
|
pub fn relu<A, Tag1, Tag2>(
|
||||||
t: RankedDifferentiableTagged<A, Tag1, 1>,
|
t: &RankedDifferentiableTagged<A, Tag1, 1>,
|
||||||
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
|
theta0: &RankedDifferentiableTagged<A, Tag2, 1>,
|
||||||
theta1: Scalar<A>,
|
theta1: Scalar<A>,
|
||||||
) -> Scalar<A>
|
) -> Scalar<A>
|
||||||
where
|
where
|
||||||
@@ -50,7 +50,7 @@ mod test_decider {
|
|||||||
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
|
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
|
||||||
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
|
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
|
||||||
|
|
||||||
let result = linear(t, theta0, theta1).real_part().into_inner();
|
let result = linear(&t, &theta0, theta1).real_part().into_inner();
|
||||||
|
|
||||||
assert!((result + 0.1).abs() < 0.000_000_01);
|
assert!((result + 0.1).abs() < 0.000_000_01);
|
||||||
}
|
}
|
||||||
@@ -61,7 +61,7 @@ mod test_decider {
|
|||||||
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
|
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
|
||||||
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
|
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
|
||||||
|
|
||||||
let result = relu(t, theta0, theta1).real_part().into_inner();
|
let result = relu(&t, &theta0, theta1).real_part().into_inner();
|
||||||
|
|
||||||
assert_eq!(result, 0.0);
|
assert_eq!(result, 0.0);
|
||||||
}
|
}
|
||||||
|
75
little_learner/src/layer.rs
Normal file
75
little_learner/src/layer.rs
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||||
|
use crate::decider::relu;
|
||||||
|
use crate::traits::NumLike;
|
||||||
|
|
||||||
|
/// Returns a tensor1.
|
||||||
|
/// Theta has two components: a tensor2 of weights and a tensor1 of bias.
|
||||||
|
pub fn layer<T>(
|
||||||
|
theta: Differentiable<T>,
|
||||||
|
t: RankedDifferentiable<T, 1>,
|
||||||
|
) -> RankedDifferentiable<T, 1>
|
||||||
|
where
|
||||||
|
T: NumLike + PartialOrd,
|
||||||
|
{
|
||||||
|
let mut theta = theta.into_vector();
|
||||||
|
assert_eq!(theta.len(), 2, "Needed weights and a bias");
|
||||||
|
let b = theta.pop().unwrap().attach_rank::<1>().unwrap();
|
||||||
|
let w = theta.pop().unwrap().attach_rank::<2>().unwrap();
|
||||||
|
|
||||||
|
RankedDifferentiableTagged::map2_once(
|
||||||
|
&w,
|
||||||
|
&b,
|
||||||
|
&mut |w: &RankedDifferentiable<_, 1>, b: &RankedDifferentiable<_, 0>| {
|
||||||
|
RankedDifferentiableTagged::of_scalar(relu(&t, w, b.clone().to_scalar()))
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use crate::auto_diff::{Differentiable, RankedDifferentiable};
    use crate::layer::layer;
    use crate::not_nan::{to_not_nan_1, to_not_nan_2};

    /// A layer of two neurons fed three inputs yields one activation per neuron.
    #[test]
    fn test_single_layer() {
        // Two neurons, three inputs each:
        //   neuron 1 has weights (3, 4, 5) and bias 1;
        //   neuron 2 has weights (6, 7, 8) and bias 2.
        let weights = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
            [3.0, 4.0, 5.0],
            [6.0, 7.0, 8.0],
        ]));
        let biases = RankedDifferentiable::of_slice(&to_not_nan_1([1.0, 2.0]));
        let theta = Differentiable::of_vec(vec![weights.to_unranked(), biases.to_unranked()]);

        // Three inputs: t = (9, 10, 11).
        let inputs = RankedDifferentiable::of_slice(&to_not_nan_1([9.0, 10.0, 11.0]));

        let activations = layer(theta, inputs)
            .to_vector()
            .iter()
            .map(|neuron| {
                neuron
                    .clone()
                    .to_scalar()
                    .clone_real_part()
                    .into_inner()
            })
            .collect::<Vec<_>>();

        // One output element per neuron.
        assert_eq!(activations.len(), 2);

        // Neuron 1 is relu(t, (3, 4, 5), 1), expected to be (9, 10, 11).(3, 4, 5) + 1.
        assert_eq!(activations[0], (9 * 3 + 10 * 4 + 11 * 5 + 1) as f64);
        // Neuron 2 is relu(t, (6, 7, 8), 2), expected to be (9, 10, 11).(6, 7, 8) + 2.
        assert_eq!(activations[1], (9 * 6 + 10 * 7 + 11 * 8 + 2) as f64);
    }
}
|
@@ -6,6 +6,7 @@ pub mod auto_diff;
|
|||||||
pub mod decider;
|
pub mod decider;
|
||||||
pub mod gradient_descent;
|
pub mod gradient_descent;
|
||||||
pub mod hyper;
|
pub mod hyper;
|
||||||
|
pub mod layer;
|
||||||
pub mod loss;
|
pub mod loss;
|
||||||
pub mod not_nan;
|
pub mod not_nan;
|
||||||
pub mod predictor;
|
pub mod predictor;
|
||||||
|
@@ -7,6 +7,10 @@ where
|
|||||||
xs.map(|x| NotNan::new(x).expect("not nan"))
|
xs.map(|x| NotNan::new(x).expect("not nan"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_not_nan_1<T, const N: usize>(xs: [NotNan<T>; N]) -> [T; N] {
|
||||||
|
xs.map(|x| x.into_inner())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
|
pub fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
|
||||||
where
|
where
|
||||||
T: ordered_float::Float,
|
T: ordered_float::Float,
|
||||||
|
Reference in New Issue
Block a user