From 1b738b200a2fb11f6733f4d493ae88c58a5bc9e7 Mon Sep 17 00:00:00 2001 From: Patrick Stevens Date: Sat, 8 Apr 2023 00:50:32 +0100 Subject: [PATCH] Separate implementation (#12) --- Cargo.lock | 4 +- little_learner/src/auto_diff.rs | 333 +++++++++++++++++++++----------- little_learner/src/loss.rs | 4 +- little_learner_app/src/main.rs | 6 +- 4 files changed, 226 insertions(+), 121 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c30cc64..367cbfa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -101,9 +101,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.53" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs index 97d59b5..8e41db3 100644 --- a/little_learner/src/auto_diff.rs +++ b/little_learner/src/auto_diff.rs @@ -12,7 +12,9 @@ where A: Zero, { fn zero() -> Differentiable { - Differentiable::Scalar(Scalar::Number(A::zero(), None)) + Differentiable { + contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None)), + } } } @@ -30,7 +32,21 @@ where A: One, { fn one() -> Differentiable { - Differentiable::Scalar(Scalar::one()) + Differentiable { + contents: DifferentiableContents::Scalar(Scalar::one()), + } + } +} + +impl Clone for DifferentiableContents +where + A: Clone, +{ + fn clone(&self) -> Self { + match self { + Self::Scalar(arg0) => Self::Scalar(arg0.clone()), + Self::Vector(arg0, rank) => Self::Vector(arg0.clone(), *rank), + } } } @@ -39,27 +55,32 @@ where A: Clone, { fn clone(&self) -> Self { - match self { - Self::Scalar(arg0) => Self::Scalar(arg0.clone()), - Self::Vector(arg0) => Self::Vector(arg0.clone()), + Differentiable { + contents: self.contents.clone(), } } } #[derive(Debug)] -pub enum Differentiable { +enum DifferentiableContents { Scalar(Scalar), - Vector(Vec>), + // Contains the rank. + Vector(Vec>, usize), } -impl Display for Differentiable +#[derive(Debug)] +pub struct Differentiable { + contents: DifferentiableContents, +} + +impl Display for DifferentiableContents where A: Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)), - Differentiable::Vector(v) => { + DifferentiableContents::Scalar(s) => f.write_fmt(format_args!("{}", s)), + DifferentiableContents::Vector(v, _rank) => { f.write_char('[')?; for v in v.iter() { f.write_fmt(format_args!("{}", v))?; @@ -71,106 +92,196 @@ where } } -impl Differentiable { - pub fn map(&self, f: &mut F) -> Differentiable +impl Display for Differentiable +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.contents)) + } +} + +impl DifferentiableContents { + fn map(&self, f: &mut F) -> DifferentiableContents where F: FnMut(Scalar) -> Scalar, A: Clone, { match self { - Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())), - Differentiable::Vector(slice) => { - Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect()) + DifferentiableContents::Scalar(a) => DifferentiableContents::Scalar(f(a.clone())), + DifferentiableContents::Vector(slice, rank) => { + DifferentiableContents::Vector(slice.iter().map(|x| x.map(f)).collect(), *rank) } } } + fn map2(&self, other: &DifferentiableContents, f: &F) -> DifferentiableContents + where + F: Fn(&Scalar, &Scalar) -> Scalar, + A: Clone, + B: Clone, + { + match (self, other) { + (DifferentiableContents::Scalar(a), DifferentiableContents::Scalar(b)) => { + DifferentiableContents::Scalar(f(a, b)) + } + ( + DifferentiableContents::Vector(slice_a, rank_a), + DifferentiableContents::Vector(slice_b, rank_b), + ) => { + if rank_a != rank_b { + panic!("Unexpectedly different ranks in map2"); + } + DifferentiableContents::Vector( + slice_a + .iter() + .zip(slice_b.iter()) + .map(|(a, b)| a.map2(b, f)) + .collect(), + *rank_a, + ) + } + _ => panic!("Wrong shapes!"), + } + } + + fn of_slice(input: T) -> DifferentiableContents + where + A: Clone, + T: AsRef<[A]>, + { + DifferentiableContents::Vector( + input + .as_ref() + .iter() + .map(|v| Differentiable { + contents: DifferentiableContents::Scalar(Scalar::Number((*v).clone(), None)), + }) + .collect(), + 1, + ) + } + + fn rank(&self) -> usize { + match self { + DifferentiableContents::Scalar(_) => 0, + DifferentiableContents::Vector(_, rank) => *rank, + } + } +} + +impl Differentiable { + pub fn map(&self, f: &mut F) -> Differentiable + where + A: Clone, + F: FnMut(Scalar) -> Scalar, + { + Differentiable { + contents: self.contents.map(f), + } + } + pub fn map2(&self, other: &Differentiable, f: &F) -> Differentiable where F: Fn(&Scalar, &Scalar) -> Scalar, A: Clone, B: Clone, { - match (self, other) { - (Differentiable::Scalar(a), Differentiable::Scalar(b)) => { - Differentiable::Scalar(f(a, b)) - } - (Differentiable::Vector(slice_a), Differentiable::Vector(slice_b)) => { - Differentiable::Vector( - slice_a - .iter() - .zip(slice_b.iter()) - .map(|(a, b)| a.map2(b, f)) - .collect(), - ) - } - _ => panic!("Wrong shapes!"), + Differentiable { + contents: self.contents.map2(&other.contents, f), } } + pub fn attach_rank( + self: Differentiable, + ) -> Option> { + if self.contents.rank() == RANK { + Some(RankedDifferentiable { contents: self }) + } else { + None + } + } + + pub fn of_scalar(s: Scalar) -> Differentiable { + Differentiable { + contents: DifferentiableContents::Scalar(s), + } + } +} + +impl DifferentiableContents { + fn into_scalar(self) -> Scalar { + match self { + DifferentiableContents::Scalar(s) => s, + DifferentiableContents::Vector(_, _) => panic!("not a scalar"), + } + } + + fn into_vector(self) -> Vec> { + match self { + DifferentiableContents::Scalar(_) => panic!("not a vector"), + DifferentiableContents::Vector(v, _) => v, + } + } + + fn borrow_scalar(&self) -> &Scalar { + match self { + DifferentiableContents::Scalar(s) => s, + DifferentiableContents::Vector(_, _) => panic!("not a scalar"), + } + } + + fn borrow_vector(&self) -> &Vec> { + match self { + DifferentiableContents::Scalar(_) => panic!("not a vector"), + DifferentiableContents::Vector(v, _) => v, + } + } +} + +impl Differentiable { + pub fn into_scalar(self) -> Scalar { + self.contents.into_scalar() + } + + pub fn into_vector(self) -> Vec> { + self.contents.into_vector() + } + + pub fn borrow_scalar(&self) -> &Scalar { + self.contents.borrow_scalar() + } + + pub fn borrow_vector(&self) -> &Vec> { + self.contents.borrow_vector() + } + fn of_slice(input: T) -> Differentiable where A: Clone, T: AsRef<[A]>, { - Differentiable::Vector( - input - .as_ref() - .iter() - .map(|v| Differentiable::Scalar(Scalar::Number((*v).clone(), None))) - .collect(), - ) + Differentiable { + contents: DifferentiableContents::of_slice(input), + } + } + + pub fn of_vec(input: Vec>) -> Differentiable { + if input.is_empty() { + panic!("Can't make an empty tensor"); + } + let rank = input[0].rank(); + Differentiable { + contents: DifferentiableContents::Vector(input, 1 + rank), + } } pub fn rank(&self) -> usize { - match self { - Differentiable::Scalar(_) => 0, - Differentiable::Vector(v) => v[0].rank() + 1, - } - } - - pub fn attach_rank( - self: Differentiable, - ) -> Option> { - if self.rank() == RANK { - Some(RankedDifferentiable { contents: self }) - } else { - None - } + self.contents.rank() } } -impl Differentiable { - pub fn into_scalar(self) -> Scalar { - match self { - Differentiable::Scalar(s) => s, - Differentiable::Vector(_) => panic!("not a scalar"), - } - } - - pub fn into_vector(self) -> Vec> { - match self { - Differentiable::Scalar(_) => panic!("not a vector"), - Differentiable::Vector(v) => v, - } - } - - pub fn borrow_scalar(&self) -> &Scalar { - match self { - Differentiable::Scalar(s) => s, - Differentiable::Vector(_) => panic!("not a scalar"), - } - } - - pub fn borrow_vector(&self) -> &Vec> { - match self { - Differentiable::Scalar(_) => panic!("not a vector"), - Differentiable::Vector(v) => v, - } - } -} - -impl Differentiable +impl DifferentiableContents where A: Clone + Eq @@ -185,17 +296,19 @@ where { fn accumulate_gradients_vec(v: &[Differentiable], acc: &mut HashMap, A>) { for v in v.iter().rev() { - v.accumulate_gradients(acc); + v.contents.accumulate_gradients(acc); } } fn accumulate_gradients(&self, acc: &mut HashMap, A>) { match self { - Differentiable::Scalar(y) => { + DifferentiableContents::Scalar(y) => { let k = y.clone_link(); k.invoke(y, A::one(), acc); } - Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc), + DifferentiableContents::Vector(y, _rank) => { + DifferentiableContents::accumulate_gradients_vec(y, acc) + } } } @@ -231,15 +344,12 @@ where impl RankedDifferentiable { pub fn to_scalar(self) -> Scalar { - match self.contents { - Differentiable::Scalar(s) => s, - Differentiable::Vector(_) => panic!("not a scalar despite teq that we're a scalar"), - } + self.contents.contents.into_scalar() } pub fn of_scalar(s: Scalar) -> RankedDifferentiable { RankedDifferentiable { - contents: Differentiable::Scalar(s), + contents: Differentiable::of_scalar(s), } } } @@ -251,7 +361,9 @@ impl RankedDifferentiable { T: AsRef<[A]>, { RankedDifferentiable { - contents: Differentiable::of_slice(input), + contents: Differentiable { + contents: DifferentiableContents::of_slice(input), + }, } } } @@ -267,7 +379,7 @@ impl RankedDifferentiable { .map(|x| Differentiable::of_slice(x)) .collect::>(); RankedDifferentiable { - contents: Differentiable::Vector(v), + contents: Differentiable::of_vec(v), } } } @@ -285,7 +397,7 @@ impl RankedDifferentiable { s: Vec>, ) -> RankedDifferentiable { RankedDifferentiable { - contents: Differentiable::Vector(s.into_iter().map(|v| v.contents).collect()), + contents: Differentiable::of_vec(s.into_iter().map(|v| v.contents).collect()), } } @@ -320,13 +432,11 @@ impl RankedDifferentiable { pub fn to_vector( self: RankedDifferentiable, ) -> Vec> { - match self.contents { - Differentiable::Scalar(_) => panic!("not a scalar"), - Differentiable::Vector(v) => v - .into_iter() - .map(|v| RankedDifferentiable { contents: v }) - .collect(), - } + self.contents + .into_vector() + .into_iter() + .map(|v| RankedDifferentiable { contents: v }) + .collect() } } @@ -357,7 +467,7 @@ where }) }); let after_f = f(&wrt); - Differentiable::grad_once(after_f.contents, wrt) + DifferentiableContents::grad_once(after_f.contents.contents, wrt) } #[cfg(test)] @@ -369,21 +479,18 @@ mod tests { use super::*; fn extract_scalar<'a, A>(d: &'a Differentiable) -> &'a A { - match d { - Differentiable::Scalar(a) => &(a.real_part()), - Differentiable::Vector(_) => panic!("not a scalar"), - } + d.borrow_scalar().real_part() } #[test] fn test_map() { - let v = Differentiable::Vector( + let v = Differentiable::of_vec( vec![ - Differentiable::Scalar(Scalar::Number( + Differentiable::of_scalar(Scalar::Number( NotNan::new(3.0).expect("3 is not NaN"), Some(0usize), )), - Differentiable::Scalar(Scalar::Number( + Differentiable::of_scalar(Scalar::Number( NotNan::new(4.0).expect("4 is not NaN"), Some(1usize), )), @@ -395,13 +502,11 @@ mod tests { Scalar::Dual(_, _) => panic!("Not hit"), }); - let v = match mapped { - Differentiable::Scalar(_) => panic!("Not a scalar"), - Differentiable::Vector(v) => v - .iter() - .map(|d| extract_scalar(d).clone()) - .collect::>(), - }; + let v = mapped + .into_vector() + .iter() + .map(|d| extract_scalar(d).clone()) + .collect::>(); assert_eq!(v, [4.0, 5.0]); } diff --git a/little_learner/src/loss.rs b/little_learner/src/loss.rs index 9bcbe43..a2044f6 100644 --- a/little_learner/src/loss.rs +++ b/little_learner/src/loss.rs @@ -126,7 +126,7 @@ where let dotted = RankedDifferentiable::of_scalar( dot_unranked( left_arg.to_unranked_borrow(), - &Differentiable::Vector(theta.to_vec()), + &Differentiable::of_vec(theta.to_vec()), ) .into_vector() .into_iter() @@ -180,7 +180,7 @@ where ); dot_unranked( x_powers.to_unranked_borrow(), - &Differentiable::Vector(theta.to_vec()), + &Differentiable::of_vec(theta.to_vec()), ) .attach_rank::<1>() .expect("wanted a tensor1") diff --git a/little_learner_app/src/main.rs b/little_learner_app/src/main.rs index da18e8e..19f15ad 100644 --- a/little_learner_app/src/main.rs +++ b/little_learner_app/src/main.rs @@ -106,7 +106,7 @@ fn main() { }, [ RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(), - Differentiable::Scalar(Scalar::zero()), + Differentiable::of_scalar(Scalar::zero()), ], hyper.iterations, ) @@ -168,7 +168,7 @@ mod tests { #[test] fn grad_example() { - let input_vec = [Differentiable::Scalar(Scalar::make( + let input_vec = [Differentiable::of_scalar(Scalar::make( NotNan::new(27.0).expect("not nan"), ))]; @@ -362,7 +362,7 @@ mod tests { }, [ RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(), - Differentiable::Scalar(Scalar::zero()), + Differentiable::of_scalar(Scalar::zero()), ], hyper.iterations, )