From 242f71fa758472789c5500ad31c55501a9b4809c Mon Sep 17 00:00:00 2001
From: Patrick Stevens
Date: Wed, 14 Jun 2023 15:58:26 +0100
Subject: [PATCH] Ext2 (#26)

* Ext1

* ext1 tests

* ext2
---
 little_learner/src/auto_diff.rs |  71 +++++++++
 little_learner/src/ext.rs       | 272 ++++++++++++++++++++++++++++++++
 little_learner/src/lib.rs       |   1 +
 3 files changed, 344 insertions(+)
 create mode 100644 little_learner/src/ext.rs

diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index cae7c0b..8debc49 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -199,6 +199,34 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
         }
     }
 
+    pub fn map_once_tagged<B, Tag2, F>(
+        self: &DifferentiableContents<A, Tag>,
+        mut f: F,
+    ) -> DifferentiableContents<B, Tag2>
+    where
+        F: FnMut(&DifferentiableTagged<A, Tag>) -> DifferentiableTagged<B, Tag2>,
+    {
+        match self {
+            DifferentiableContents::Scalar(_, _) => {
+                panic!("can't map_once_tagged into a scalar");
+            }
+            DifferentiableContents::Vector(v, _rank) => {
+                assert_ne!(v.len(), 0, "Can't get rank of an empty vector");
+                let mut rank = 0;
+                DifferentiableContents::Vector(
+                    v.iter()
+                        .map(|x| {
+                            let result = f(x);
+                            rank = result.rank();
+                            result
+                        })
+                        .collect(),
+                    rank + 1,
+                )
+            }
+        }
+    }
+
     /// Unwraps one layer of each input, so the passed function takes inputs which have decreased
     /// the ranks of the `map2_once_tagged` input by one.
     /// Panics if passed a scalar or if the input vectors are not the same length.
@@ -335,6 +363,18 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
         }
     }
 
+    pub fn map_once_tagged<B, Tag2, F>(
+        self: &DifferentiableTagged<A, Tag>,
+        f: F,
+    ) -> DifferentiableTagged<B, Tag2>
+    where
+        F: FnMut(&DifferentiableTagged<A, Tag>) -> DifferentiableTagged<B, Tag2>,
+    {
+        DifferentiableTagged {
+            contents: self.contents.map_once_tagged(f),
+        }
+    }
+
     pub fn map2_once_tagged<B, Tag2, C, Tag3, F>(
         self: &DifferentiableTagged<A, Tag>,
         other: &DifferentiableTagged<B, Tag2>,
@@ -519,6 +559,37 @@ pub struct RankedDifferentiableTagged<A, Tag, const RANK: usize> {
     contents: DifferentiableTagged<A, Tag>,
 }
 
+impl<A, Tag, const RANK2: usize> RankedDifferentiableTagged<A, Tag, RANK2> {
+    pub fn map_once_tagged<B, Tag2, F, const RANK: usize>(
+        &self,
+        f: &mut F,
+    ) -> DifferentiableTagged<B, Tag2>
+    where
+        A: Clone,
+        Tag: Clone,
+        B: Clone,
+        Tag2: Clone,
+        F: FnMut(&RankedDifferentiableTagged<A, Tag, RANK>) -> DifferentiableTagged<B, Tag2>,
+    {
+        match &self.contents.contents {
+            DifferentiableContents::Scalar(_, _) => {
+                panic!("forbidden by the types")
+            }
+            DifferentiableContents::Vector(v, rank) => {
+                assert_eq!(*rank, RANK2);
+                DifferentiableTagged {
+                    contents: DifferentiableContents::Vector(
+                        v.iter()
+                            .map(|x| f(&(*x).clone().attach_rank::<RANK>().unwrap()))
+                            .collect(),
+                        RANK2 - 1,
+                    ),
+                }
+            }
+        }
+    }
+}
+
 impl<A, Tag, const RANK: usize> Display for RankedDifferentiableTagged<A, Tag, RANK>
 where
     A: Display,
diff --git a/little_learner/src/ext.rs b/little_learner/src/ext.rs
new file mode 100644
index 0000000..e6e60b8
--- /dev/null
+++ b/little_learner/src/ext.rs
@@ -0,0 +1,272 @@
+use crate::auto_diff::{Differentiable, DifferentiableTagged, RankedDifferentiable};
+use std::iter::Sum;
+use std::ops::Mul;
+
+pub fn ext1<A, Tag, B, Tag2, F>(
+    n: usize,
+    f: &mut F,
+    t: &DifferentiableTagged<A, Tag>,
+) -> DifferentiableTagged<B, Tag2>
+where
+    F: FnMut(&DifferentiableTagged<A, Tag>) -> DifferentiableTagged<B, Tag2>,
+{
+    if t.rank() == n {
+        f(t)
+    } else {
+        t.map_once_tagged(|x| ext1(n, f, x))
+    }
+}
+
+pub fn ext2<A, Tag, B, Tag2, C, Tag3, F>(
+    n: usize,
+    m: usize,
+    f: &mut F,
+    t: &DifferentiableTagged<A, Tag>,
+    u: &DifferentiableTagged<B, Tag2>,
+) -> DifferentiableTagged<C, Tag3>
+where
+    F: FnMut(
+        &DifferentiableTagged<A, Tag>,
+        &DifferentiableTagged<B, Tag2>,
+    ) -> DifferentiableTagged<C, Tag3>,
+    A: Clone,
+    Tag: Clone,
+    B: Clone,
+    Tag2: Clone,
+{
+    if t.rank() == n && u.rank() == m {
+        f(t, u)
+    } else if t.rank() == n {
+        u.map_once_tagged(|eu| ext2(n, m, f, t, eu))
+    } else if u.rank() == m {
+        t.map_once_tagged(|et| ext2(n, m, f, et, u))
+    } else if t.rank() == u.rank() {
+        t.map2_once_tagged(u, |t, u| ext2(n, m, f, t, u))
+    } else if t.rank() > u.rank() {
+        t.map_once_tagged(|et| ext2(n, m, f, et, u))
+    } else {
+        u.map_once_tagged(|eu| ext2(n, m, f, t, eu))
+    }
+}
+
+pub fn elementwise_mul_via_ext<A, const RANK1: usize, const RANK2: usize>(
+    x: &RankedDifferentiable<A, RANK1>,
+    y: &RankedDifferentiable<A, RANK2>,
+) -> RankedDifferentiable<A, RANK1>
+where
+    A: Mul + Sum<<A as Mul>::Output> + Clone + Default,
+{
+    ext2(
+        0,
+        0,
+        &mut |x, y| Differentiable::of_scalar(x.clone().into_scalar() * y.clone().into_scalar()),
+        x.to_unranked_borrow(),
+        y.to_unranked_borrow(),
+    )
+    .attach_rank::<RANK1>()
+    .unwrap()
+}
+
+/// Produce the matrix multiplication of the inputs, threading where necessary until the
+/// first argument has rank 2 and the second argument has rank 1.
+pub fn star_2_1<T>(x: &Differentiable<T>, y: &Differentiable<T>) -> Differentiable<T>
+where
+    T: Clone + Sum + Mul<Output = T> + Default,
+{
+    ext2(
+        2,
+        1,
+        &mut |x, y| {
+            elementwise_mul_via_ext(
+                &x.clone().attach_rank::<2>().unwrap(),
+                &y.clone().attach_rank::<1>().unwrap(),
+            )
+            .to_unranked()
+        },
+        x,
+        y,
+    )
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::auto_diff::{Differentiable, RankedDifferentiable};
+    use crate::ext::{elementwise_mul_via_ext, ext1, ext2, star_2_1};
+    use crate::not_nan::{to_not_nan_1, to_not_nan_2};
+    use crate::scalar::Scalar;
+    use crate::traits::Zero;
+    use ordered_float::NotNan;
+    use std::iter::Sum;
+    use std::ops::Mul;
+
+    fn zeros_redefined<A>(t: &Differentiable<A>) -> Differentiable<A>
+    where
+        A: Zero,
+    {
+        ext1(
+            0,
+            &mut |_| Differentiable::of_scalar(Scalar::make(A::zero())),
+            t,
+        )
+    }
+
+    #[test]
+    fn define_zeros() {
+        let shape = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+            [1.0, 2.0],
+            [3.0, 4.0],
+            [5.0, 6.0],
+        ]));
+        let zeros = zeros_redefined(&shape.to_unranked());
+        let to_zeros = zeros
+            .attach_rank::<2>()
+            .unwrap()
+            .to_vector()
+            .iter()
+            .map(|x| {
+                (*x).clone()
+                    .to_vector()
+                    .iter()
+                    .map(|x| (*x).clone().to_scalar().clone_real_part().into_inner())
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<_>>();
+        assert_eq!(to_zeros, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+    }
+
+    fn flatten_2<A>(t: RankedDifferentiable<A, 2>) -> RankedDifferentiable<A, 1>
+    where
+        A: Clone,
+    {
+        let mut result = Vec::new();
+        for v in t.to_unranked_borrow().borrow_vector() {
+            result.extend((*v.borrow_vector()).clone())
+        }
+        Differentiable::of_vec(result).attach_rank::<1>().unwrap()
+    }
+
+    #[test]
+    fn test_flatten_2() {
+        let input = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+            [1.0, 0.5],
+            [3.1, 2.2],
+            [7.3, 2.1],
+        ]));
+        let flattened = flatten_2(input);
+        let reshaped = flattened
+            .to_vector()
+            .iter()
+            .map(|x| (*x).clone().to_scalar().clone_real_part().into_inner())
+            .collect::<Vec<_>>();
+        assert_eq!(reshaped, [1.0, 0.5, 3.1, 2.2, 7.3, 2.1])
+    }
+
+    #[test]
+    fn test_flatten() {
+        let flatten = |t: &Differentiable<NotNan<f64>>| {
+            ext1(
+                2,
+                &mut |t| flatten_2((*t).clone().attach_rank::<2>().unwrap()).to_unranked(),
+                t,
+            )
+        };
+        let input = RankedDifferentiable::of_vector(vec![
+            RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+                [1.0, 0.5],
+                [3.1, 2.2],
+                [7.3, 2.1],
+            ])),
+            RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+                [2.9, 3.5],
+                [0.7, 1.5],
+                [2.5, 6.4],
+            ])),
+        ]);
+
+        let flattened = flatten(&input.to_unranked())
+            .attach_rank::<2>()
+            .unwrap()
+            .to_vector()
+            .iter()
+            .map(|i| {
+                i.to_unranked_borrow()
+                    .borrow_vector()
+                    .iter()
+                    .map(|j| j.borrow_scalar().clone_real_part().into_inner())
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<_>>();
+
+        assert_eq!(
+            flattened,
+            [
+                [1.0, 0.5, 3.1, 2.2, 7.3, 2.1],
+                [2.9, 3.5, 0.7, 1.5, 2.5, 6.4]
+            ]
+        )
+    }
+
+    #[test]
+    fn test_star_2_1_a() {
+        let input1 = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+            [3.0, 4.0, 5.0],
+            [7.0, 8.0, 9.0],
+        ]));
+        let input2 = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 4.0, 3.0]));
+
+        let output = star_2_1(input1.to_unranked_borrow(), input2.to_unranked_borrow())
+            .into_vector()
+            .iter()
+            .map(|x| {
+                x.clone()
+                    .into_vector()
+                    .iter()
+                    .map(|i| i.clone().into_scalar().clone_real_part().into_inner())
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<_>>();
+
+        assert_eq!(output, [[6.0, 16.0, 15.0], [14.0, 32.0, 27.0]])
+    }
+
+    #[test]
+    fn test_star_2_1_b() {
+        let input1 = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+            [8.0, 1.0],
+            [7.0, 3.0],
+            [5.0, 4.0],
+        ]));
+        let input2 = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+            [6.0, 2.0],
+            [4.0, 9.0],
+            [3.0, 8.0],
+        ]));
+
+        let output = star_2_1(input1.to_unranked_borrow(), input2.to_unranked_borrow())
+            .into_vector()
+            .iter()
+            .map(|x| {
+                x.clone()
+                    .into_vector()
+                    .iter()
+                    .map(|i| {
+                        i.clone()
+                            .into_vector()
+                            .iter()
+                            .map(|i| i.borrow_scalar().clone_real_part().into_inner())
+                            .collect::<Vec<_>>()
+                    })
+                    .collect::<Vec<_>>()
+            })
+            .collect::<Vec<_>>();
+
+        assert_eq!(
+            output,
+            [
+                [[48.0, 2.0], [42.0, 6.0], [30.0, 8.0]],
+                [[32.0, 9.0], [28.0, 27.0], [20.0, 36.0]],
+                [[24.0, 8.0], [21.0, 24.0], [15.0, 32.0]]
+            ]
+        )
+    }
+}
diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs
index 73e528b..cdea666 100644
--- a/little_learner/src/lib.rs
+++ b/little_learner/src/lib.rs
@@ -4,6 +4,7 @@
 pub mod auto_diff;
 pub mod decider;
+pub mod ext;
 pub mod gradient_descent;
 pub mod hyper;
 pub mod layer;
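
A brief orientation note on the two operators this patch adds: ext1(n, f, t) applies f directly when t already has rank n and otherwise recurses into t's elements, while ext2(n, m, f, t, u) does the same over two arguments, descending into whichever argument still has too high a rank (or into both at once, via map2_once_tagged, when their ranks match). The sketch below is illustrative only and is not part of the commit; it is written as if it lived next to the tests in little_learner/src/ext.rs, so the crate:: imports and helpers (to_not_nan_1, to_not_nan_2, Scalar::make, the crate's Zero trait) are exactly the ones the tests above already use.

    // Hypothetical usage sketch; mirrors `define_zeros` and `test_star_2_1_a`.
    use crate::auto_diff::{Differentiable, RankedDifferentiable};
    use crate::ext::{ext1, star_2_1};
    use crate::not_nan::{to_not_nan_1, to_not_nan_2};
    use crate::scalar::Scalar;
    use crate::traits::Zero;
    use ordered_float::NotNan;

    fn demo() {
        // ext1 with n = 0 lifts a scalar-level operation over a whole tensor:
        // here it rebuilds every entry of a 2x2 matrix as zero, whatever the shape.
        let t = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([[1.0, 2.0], [3.0, 4.0]]));
        let zeros: Differentiable<NotNan<f64>> = ext1(
            0,
            &mut |_| Differentiable::of_scalar(Scalar::make(<NotNan<f64> as Zero>::zero())),
            &t.to_unranked(),
        );

        // star_2_1 threads an elementwise product over a rank-2 and a rank-1
        // argument: a 2x3 matrix against a length-3 vector yields
        // [[6.0, 16.0, 15.0], [14.0, 32.0, 27.0]], as in test_star_2_1_a.
        let m = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
            [3.0, 4.0, 5.0],
            [7.0, 8.0, 9.0],
        ]));
        let v = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 4.0, 3.0]));
        let product = star_2_1(m.to_unranked_borrow(), v.to_unranked_borrow());
        let _ = (zeros, product);
    }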