diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs index 539362c..cae7c0b 100644 --- a/little_learner/src/auto_diff.rs +++ b/little_learner/src/auto_diff.rs @@ -69,7 +69,7 @@ where #[derive(Debug)] enum DifferentiableContents { Scalar(Scalar, Tag), - // Contains the rank. + // Contains the rank of this differentiable (i.e. one more than the rank of the inputs). Vector(Vec>, usize), } @@ -199,6 +199,64 @@ impl DifferentiableContents { } } + /// Unwraps one layer of each input, so the passed function takes inputs which have decreased + /// the ranks of the `map2_once_tagged` input by one. + /// Panics if passed a scalar or if the input vectors are not the same length. + pub fn map2_once_tagged( + self: &DifferentiableContents, + other: &DifferentiableContents, + mut f: F, + ) -> DifferentiableContents + where + F: FnMut( + &DifferentiableTagged, + &DifferentiableTagged, + ) -> DifferentiableTagged, + { + match (self, other) { + (DifferentiableContents::Scalar(_, _), _) => { + panic!("First arg needed to have non-scalar rank") + } + (_, DifferentiableContents::Scalar(_, _)) => { + panic!("Second arg needed to have non-scalar rank") + } + ( + DifferentiableContents::Vector(v1, rank1), + DifferentiableContents::Vector(v2, _rank2), + ) => { + assert_eq!( + v1.len(), + v2.len(), + "Must map two vectors of the same length, got {rank1} and {_rank2}" + ); + assert_ne!( + v1.len(), + 0, + "Cannot determine a rank of a zero-length vector" + ); + let mut rank = 0usize; + DifferentiableContents::Vector( + v1.iter() + .zip(v2.iter()) + .map(|(a, b)| { + let result = f(a, b); + match result.contents { + DifferentiableContents::Vector(_, discovered_rank) => { + rank = discovered_rank + 1; + } + DifferentiableContents::Scalar(_, _) => { + rank = 1; + } + } + result + }) + .collect(), + rank, + ) + } + } + } + fn of_slice<'a, T, I>(tag: Tag, input: I) -> DifferentiableContents where T: Clone + 'a, @@ -277,6 +335,22 @@ impl DifferentiableTagged { } } + pub fn map2_once_tagged( + self: &DifferentiableTagged, + other: &DifferentiableTagged, + f: F, + ) -> DifferentiableTagged + where + F: FnMut( + &DifferentiableTagged, + &DifferentiableTagged, + ) -> DifferentiableTagged, + { + DifferentiableTagged { + contents: self.contents.map2_once_tagged(&other.contents, f), + } + } + pub fn attach_rank( self: DifferentiableTagged, ) -> Option> { @@ -582,10 +656,10 @@ impl RankedDifferentiableTagged { } } - pub fn map2_tagged( - self: &RankedDifferentiableTagged, - other: &RankedDifferentiableTagged, - f: &mut F, + pub fn map2_tagged<'a, 'b, B, C, Tag2, Tag3, F>( + self: &'a RankedDifferentiableTagged, + other: &'a RankedDifferentiableTagged, + f: &'b mut F, ) -> RankedDifferentiableTagged where F: FnMut(&Scalar, Tag, &Scalar, Tag2) -> (Scalar, Tag3), @@ -598,6 +672,44 @@ impl RankedDifferentiableTagged { contents: DifferentiableTagged::map2_tagged(&self.contents, &other.contents, f), } } + pub fn map2_once_tagged< + 'a, + 'c, + B, + C: 'a, + Tag2, + Tag3: 'a, + F, + const RANK_B: usize, + const RANK_OUT: usize, + >( + self: &'a RankedDifferentiableTagged, + other: &'a RankedDifferentiableTagged, + f: &'c mut F, + ) -> RankedDifferentiableTagged + where + F: FnMut( + &RankedDifferentiableTagged, + &RankedDifferentiableTagged, + ) -> RankedDifferentiableTagged, + A: Clone, + B: Clone, + Tag: Clone, + Tag2: Clone, + 'c: 'a, + { + RankedDifferentiableTagged { + contents: DifferentiableTagged::map2_once_tagged( + &self.contents, + &other.contents, + &mut |a: &DifferentiableTagged, b: &DifferentiableTagged| { + let a = (*a).clone().attach_rank::<{ RANK - 1 }>().unwrap(); + let b = (*b).clone().attach_rank::<{ RANK_B - 1 }>().unwrap(); + f(&a, &b).to_unranked() + }, + ), + } + } pub fn to_vector( self: RankedDifferentiableTagged, @@ -634,6 +746,22 @@ impl RankedDifferentiable { { self.map2_tagged(other, &mut |a, (), b, ()| (f(a, b), ())) } + + pub fn map2_once( + self: &RankedDifferentiable, + other: &RankedDifferentiable, + f: &mut F, + ) -> RankedDifferentiable + where + F: FnMut( + &RankedDifferentiable, + &RankedDifferentiable, + ) -> RankedDifferentiable, + A: Clone, + B: Clone, + { + self.map2_once_tagged(other, f) + } } pub fn grad( diff --git a/little_learner/src/decider.rs b/little_learner/src/decider.rs index 9dcbe6c..c979fb9 100644 --- a/little_learner/src/decider.rs +++ b/little_learner/src/decider.rs @@ -15,19 +15,19 @@ where } fn linear( - t: RankedDifferentiableTagged, - theta0: RankedDifferentiableTagged, + t: &RankedDifferentiableTagged, + theta0: &RankedDifferentiableTagged, theta1: Scalar, ) -> Scalar where A: NumLike, { - dot(&theta0, &t) + theta1 + dot(theta0, t) + theta1 } pub fn relu( - t: RankedDifferentiableTagged, - theta0: RankedDifferentiableTagged, + t: &RankedDifferentiableTagged, + theta0: &RankedDifferentiableTagged, theta1: Scalar, ) -> Scalar where @@ -50,7 +50,7 @@ mod test_decider { let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan")); let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0])); - let result = linear(t, theta0, theta1).real_part().into_inner(); + let result = linear(&t, &theta0, theta1).real_part().into_inner(); assert!((result + 0.1).abs() < 0.000_000_01); } @@ -61,7 +61,7 @@ mod test_decider { let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan")); let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0])); - let result = relu(t, theta0, theta1).real_part().into_inner(); + let result = relu(&t, &theta0, theta1).real_part().into_inner(); assert_eq!(result, 0.0); } diff --git a/little_learner/src/layer.rs b/little_learner/src/layer.rs new file mode 100644 index 0000000..18117e2 --- /dev/null +++ b/little_learner/src/layer.rs @@ -0,0 +1,75 @@ +use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged}; +use crate::decider::relu; +use crate::traits::NumLike; + +/// Returns a tensor1. +/// Theta has two components: a tensor2 of weights and a tensor1 of bias. +pub fn layer( + theta: Differentiable, + t: RankedDifferentiable, +) -> RankedDifferentiable +where + T: NumLike + PartialOrd, +{ + let mut theta = theta.into_vector(); + assert_eq!(theta.len(), 2, "Needed weights and a bias"); + let b = theta.pop().unwrap().attach_rank::<1>().unwrap(); + let w = theta.pop().unwrap().attach_rank::<2>().unwrap(); + + RankedDifferentiableTagged::map2_once( + &w, + &b, + &mut |w: &RankedDifferentiable<_, 1>, b: &RankedDifferentiable<_, 0>| { + RankedDifferentiableTagged::of_scalar(relu(&t, w, b.clone().to_scalar())) + }, + ) +} + +#[cfg(test)] +mod tests { + use crate::auto_diff::{Differentiable, RankedDifferentiable}; + use crate::layer::layer; + use crate::not_nan::{to_not_nan_1, to_not_nan_2}; + + #[test] + fn test_single_layer() { + let b = RankedDifferentiable::of_slice(&to_not_nan_1([1.0, 2.0])); + let w = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([ + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + ])); + let theta = Differentiable::of_vec(vec![w.to_unranked(), b.to_unranked()]); + + /* + Two neurons: + w = + (3 4 5 + 6 7 8) + b = (1, 2) + + Three inputs: + t = (9, 10, 11) + + Output has two elements, one per neuron. + Neuron 1 has weights (3,4,5) and bias 1; + Neuron 2 has weights (6,7,8) and bias 2. + + Neuron 1 is relu(t, (3,4,5), 1), which is (9, 10, 11).(3, 4, 5) + 1. + Neuron 2 is relu(t, (6,7,8), 2), which is (9, 10, 11).(6, 7, 8) + 2. + */ + + let t = RankedDifferentiable::of_slice(&to_not_nan_1([9.0, 10.0, 11.0])); + let mut output = layer(theta, t) + .to_vector() + .iter() + .map(|t| (*t).clone().to_scalar().clone_real_part().into_inner()) + .collect::>(); + + assert_eq!(output.len(), 2); + let result_2 = output.pop().unwrap(); + let result_1 = output.pop().unwrap(); + + assert_eq!(result_1, (9 * 3 + 10 * 4 + 11 * 5 + 1) as f64); + assert_eq!(result_2, (9 * 6 + 10 * 7 + 11 * 8 + 2) as f64); + } +} diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs index 35daaf9..73e528b 100644 --- a/little_learner/src/lib.rs +++ b/little_learner/src/lib.rs @@ -6,6 +6,7 @@ pub mod auto_diff; pub mod decider; pub mod gradient_descent; pub mod hyper; +pub mod layer; pub mod loss; pub mod not_nan; pub mod predictor; diff --git a/little_learner/src/not_nan.rs b/little_learner/src/not_nan.rs index 02b698b..416235d 100644 --- a/little_learner/src/not_nan.rs +++ b/little_learner/src/not_nan.rs @@ -7,6 +7,10 @@ where xs.map(|x| NotNan::new(x).expect("not nan")) } +pub fn from_not_nan_1(xs: [NotNan; N]) -> [T; N] { + xs.map(|x| x.into_inner()) +} + pub fn to_not_nan_2(xs: [[T; N]; M]) -> [[NotNan; N]; M] where T: ordered_float::Float,