diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index f305936..de60ff8 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -12,7 +12,7 @@ where
     A: Zero,
 {
     fn zero() -> DifferentiableHidden<A> {
-        DifferentiableHidden::Scalar(Scalar::Number(A::zero()))
+        DifferentiableHidden::Scalar(Scalar::Number(A::zero(), None))
     }
 }
 
@@ -21,7 +21,7 @@ where
     A: One,
 {
     fn one() -> Scalar<A> {
-        Scalar::Number(A::one())
+        Scalar::Number(A::one(), None)
     }
 }
 
@@ -46,6 +46,7 @@ where
     }
 }
+#[derive(Debug)]
 enum DifferentiableHidden<A> {
     Scalar(Scalar<A>),
     Vector(Vec<DifferentiableHidden<A>>),
 }
@@ -71,9 +72,9 @@ where
 }
 
 impl<A> DifferentiableHidden<A> {
-    fn map<F>(&self, f: &F) -> DifferentiableHidden<A>
+    fn map<F>(&self, f: &mut F) -> DifferentiableHidden<A>
     where
-        F: Fn(Scalar<A>) -> Scalar<A>,
+        F: FnMut(Scalar<A>) -> Scalar<A>,
         A: Clone,
     {
         match self {
@@ -114,7 +115,7 @@ impl<A> DifferentiableHidden<A> {
         DifferentiableHidden::Vector(
             input
                 .iter()
-                .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone())))
+                .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone(), None)))
                 .collect(),
         )
     }
@@ -131,7 +132,8 @@ where
         + Div<Output = A>
         + Zero
         + One
-        + Neg<Output = A>,
+        + Neg<Output = A>
+        + Display,
 {
     fn accumulate_gradients_vec(v: &[DifferentiableHidden<A>], acc: &mut HashMap<Scalar<A>, A>) {
         for v in v.iter().rev() {
@@ -155,14 +157,14 @@ where
         let mut acc = HashMap::new();
         self.accumulate_gradients(&mut acc);
 
-        wrt.map(&|d| match acc.get(&d) {
-            None => Scalar::Number(A::zero()),
-            Some(x) => Scalar::Number(x.clone()),
+        wrt.map(&mut |d| match acc.get(&d) {
+            None => Scalar::Number(A::zero(), None),
+            Some(x) => Scalar::Number(x.clone(), None),
         })
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct Differentiable<A> {
     contents: DifferentiableHidden<A>,
 }
@@ -205,9 +207,9 @@ impl<A> Differentiable<A> {
         }
     }
 
-    pub fn map<F>(s: Differentiable<A>, f: &F) -> Differentiable<A>
+    pub fn map<F>(s: Differentiable<A>, f: &mut F) -> Differentiable<A>
     where
-        F: Fn(Scalar<A>) -> Scalar<A>,
+        F: FnMut(Scalar<A>) -> Scalar<A>,
         A: Clone,
     {
         Differentiable {
@@ -252,9 +254,15 @@ impl<A> Differentiable<A> {
             + Zero
             + One
            + Neg<Output = A>
-            + Eq,
+            + Eq
+            + std::fmt::Display,
     {
-        let wrt = theta.contents.map(&Scalar::truncate_dual);
+        let mut i = 0usize;
+        let wrt = theta.contents.map(&mut |x| {
+            let result = Scalar::truncate_dual(x, i);
+            i += 1;
+            result
+        });
         let after_f = f(Differentiable {
             contents: wrt.clone(),
         });
@@ -268,6 +276,8 @@ impl<A> Differentiable<A> {
 mod tests {
     use ordered_float::NotNan;
 
+    use crate::loss::{l2_loss_2, predict_line_2};
+
     use super::*;
 
     fn extract_scalar<'a, A>(d: &'a DifferentiableHidden<A>) -> &'a A {
@@ -283,15 +293,17 @@ mod tests {
             vec![
                 DifferentiableHidden::Scalar(Scalar::Number(
                     NotNan::new(3.0).expect("3 is not NaN"),
+                    Some(0usize),
                 )),
                 DifferentiableHidden::Scalar(Scalar::Number(
                     NotNan::new(4.0).expect("4 is not NaN"),
+                    Some(1usize),
                 )),
             ]
             .into(),
         );
 
-        let mapped = v.map(&|x: Scalar<NotNan<f64>>| match x {
-            Scalar::Number(i) => Scalar::Number(i + NotNan::new(1.0).expect("1 is not NaN")),
+        let mapped = v.map(&mut |x: Scalar<NotNan<f64>>| match x {
+            Scalar::Number(i, n) => Scalar::Number(i + NotNan::new(1.0).expect("1 is not NaN"), n),
             Scalar::Dual(_, _) => panic!("Not hit"),
         });
@@ -305,4 +317,29 @@ mod tests {
 
         assert_eq!(v, [4.0, 5.0]);
     }
+
+    #[test]
+    fn test_autodiff() {
+        let input_vec = of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()]);
+        let xs = [2.0, 1.0, 4.0, 3.0].map(|x| NotNan::new(x).expect("not nan"));
+        let ys = [1.8, 1.2, 4.2, 3.3].map(|x| NotNan::new(x).expect("not nan"));
+        let grad = Differentiable::grad(
+            |x| {
+                Differentiable::of_vector(vec![of_scalar(l2_loss_2(
+                    predict_line_2,
+                    of_slice(&xs),
+                    of_slice(&ys),
+                    x,
+                ))])
+            },
+            input_vec,
+        );
+
+        let grad_vec: Vec<f64> = Differentiable::to_vector(grad)
+            .into_iter()
+            .map(to_scalar)
+            .map(|x| f64::from(*x.real_part()))
+            .collect();
+        assert_eq!(grad_vec, vec![-63.0, -21.0]);
+    }
 }
diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs
index 49511b1..cf17061 100644
--- a/little_learner/src/lib.rs
+++ b/little_learner/src/lib.rs
@@ -3,6 +3,7 @@
 pub mod auto_diff;
 pub mod expr_syntax_tree;
+pub mod loss;
 pub mod scalar;
 pub mod tensor;
 pub mod traits;
 
diff --git a/little_learner/src/loss.rs b/little_learner/src/loss.rs
new file mode 100644
index 0000000..ad04f25
--- /dev/null
+++ b/little_learner/src/loss.rs
@@ -0,0 +1,93 @@
+use std::{
+    iter::Sum,
+    ops::{Add, Mul, Neg},
+};
+
+use crate::{
+    auto_diff::{of_scalar, to_scalar, Differentiable},
+    scalar::Scalar,
+    traits::{One, Zero},
+};
+
+pub fn square<A>(x: &A) -> A
+where
+    A: Mul<Output = A> + Clone,
+{
+    x.clone() * x.clone()
+}
+
+pub fn dot_2<A>(
+    x: &Differentiable<A>,
+    y: &Differentiable<A>,
+) -> Differentiable<A>
+where
+    A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default,
+{
+    Differentiable::map2(x, y, &|x, y| x.clone() * y.clone())
+}
+
+fn squared_2<A>(x: &Differentiable<A>) -> Differentiable<A>
+where
+    A: Mul<Output = A> + Copy + Default,
+{
+    Differentiable::map2(x, x, &|x, y| x.clone() * y.clone())
+}
+
+fn sum_2<A>(x: Differentiable<A>) -> Scalar<A>
+where
+    A: Sum + Copy + Add<Output = A> + Zero,
+{
+    Differentiable::to_vector(x)
+        .into_iter()
+        .map(to_scalar)
+        .sum()
+}
+
+fn l2_norm_2<A>(prediction: &Differentiable<A>, data: &Differentiable<A>) -> Scalar<A>
+where
+    A: Sum + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
+{
+    let diff = Differentiable::map2(prediction, data, &|x, y| x.clone() - y.clone());
+    sum_2(squared_2(&diff))
+}
+
+pub fn l2_loss_2<A, F, Params>(
+    target: F,
+    data_xs: Differentiable<A>,
+    data_ys: Differentiable<A>,
+    params: Params,
+) -> Scalar<A>
+where
+    F: Fn(Differentiable<A>, Params) -> Differentiable<A>,
+    A: Sum + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero,
+{
+    let pred_ys = target(data_xs, params);
+    l2_norm_2(&pred_ys, &data_ys)
+}
+
+pub fn predict_line_2<A>(
+    xs: Differentiable<A>,
+    theta: Differentiable<A>,
+) -> Differentiable<A>
+where
+    A: Mul<Output = A> + Add<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + One + Zero,
+{
+    let xs = Differentiable::to_vector(xs)
+        .into_iter()
+        .map(|v| to_scalar(v));
+    let mut result = vec![];
+    for x in xs {
+        let left_arg = Differentiable::of_vector(vec![
+            of_scalar(x.clone()),
+            of_scalar(<Scalar<A> as One>::one()),
+        ]);
+        let dotted = of_scalar(
+            Differentiable::to_vector(dot_2(&left_arg, &theta))
+                .iter()
+                .map(|x| to_scalar((*x).clone()))
+                .sum(),
+        );
+        result.push(dotted);
+    }
+    Differentiable::of_vector(result)
+}
diff --git a/little_learner/src/scalar.rs b/little_learner/src/scalar.rs
index 01b9014..e20c2ed 100644
--- a/little_learner/src/scalar.rs
+++ b/little_learner/src/scalar.rs
@@ -7,7 +7,7 @@ use std::{
     ops::{Add, AddAssign, Div, Mul, Neg, Sub},
 };
 
-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum LinkData<A> {
     Addition(Box<Scalar<A>>, Box<Scalar<A>>),
     Neg(Box<Scalar<A>>),
@@ -16,9 +16,9 @@ pub enum LinkData<A> {
     Log(Box<Scalar<A>>),
 }
 
-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum Link<A> {
-    EndOfLink,
+    EndOfLink(Option<usize>),
     Link(LinkData<A>),
 }
 
@@ -28,7 +28,8 @@ where
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Link::EndOfLink => f.write_str("<end of link>"),
+            Link::EndOfLink(Some(i)) => f.write_fmt(format_args!("<end of link {}>", *i)),
+            Link::EndOfLink(None) => f.write_str("<end of link>"),
             Link::Link(LinkData::Addition(left, right)) => {
                 f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref()))
             }
@@ -59,7 +60,7 @@ impl<A> Link<A> {
             + One,
     {
         match self {
-            Link::EndOfLink => match acc.entry(d.clone()) {
+            Link::EndOfLink(_) => match acc.entry(d.clone()) {
                 Entry::Occupied(mut o) => {
                     let entry = o.get_mut();
                     *entry += z;
@@ -113,9 +114,9 @@ impl<A> Link<A> {
     }
 }
 
-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum Scalar<A> {
-    Number(A),
+    Number(A, Option<usize>), // The value, and the link.
     Dual(A, Link<A>),
 }
@@ -125,7 +126,7 @@ where
     A: Zero,
 {
     fn zero() -> Self {
-        Scalar::Number(A::zero())
+        Scalar::Number(A::zero(), None)
     }
 }
 
@@ -198,7 +199,7 @@ where
 impl<A> Scalar<A> {
     pub fn real_part(&self) -> &A {
         match self {
-            Scalar::Number(a) => a,
+            Scalar::Number(a, _) => a,
             Scalar::Dual(a, _) => a,
         }
     }
@@ -208,7 +209,7 @@ impl<A> Scalar<A> {
         A: Clone,
     {
         match self {
-            Scalar::Number(a) => (*a).clone(),
+            Scalar::Number(a, _) => (*a).clone(),
             Scalar::Dual(a, _) => (*a).clone(),
         }
     }
@@ -216,7 +217,7 @@ impl<A> Scalar<A> {
     pub fn link(self) -> Link<A> {
         match self {
             Scalar::Dual(_, link) => link,
-            Scalar::Number(_) => Link::EndOfLink,
+            Scalar::Number(_, i) => Link::EndOfLink(i),
         }
     }
 
@@ -226,15 +227,15 @@ impl<A> Scalar<A> {
     {
         match self {
             Scalar::Dual(_, data) => data.clone(),
-            Scalar::Number(_) => Link::EndOfLink,
+            Scalar::Number(_, i) => Link::EndOfLink(*i),
         }
     }
 
-    pub fn truncate_dual(self) -> Scalar<A>
+    pub fn truncate_dual(self, index: usize) -> Scalar<A>
     where
         A: Clone,
     {
-        Scalar::Dual(self.clone_real_part(), Link::EndOfLink)
+        Scalar::Dual(self.clone_real_part(), Link::EndOfLink(Some(index)))
     }
 }
 
@@ -244,8 +245,9 @@ where
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Scalar::Number(n) => f.write_fmt(format_args!("{}", n)),
-            Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)),
+            Scalar::Number(n, Some(index)) => f.write_fmt(format_args!("{}_{}", n, index)),
+            Scalar::Number(n, None) => f.write_fmt(format_args!("{}", n)),
+            Scalar::Dual(n, link) => f.write_fmt(format_args!("<{}, link: {}>", n, link)),
         }
     }
 }
diff --git a/little_learner_app/src/main.rs b/little_learner_app/src/main.rs
index 5c6aad7..be1d840 100644
--- a/little_learner_app/src/main.rs
+++ b/little_learner_app/src/main.rs
@@ -3,110 +3,53 @@
 mod with_tensor;
 
-use little_learner::auto_diff::{of_scalar, of_slice, to_scalar, Differentiable};
-use little_learner::scalar::Scalar;
-use little_learner::traits::{One, Zero};
-use ordered_float::NotNan;
+use little_learner::auto_diff::{of_scalar, of_slice, Differentiable};
 
-use std::iter::Sum;
-use std::ops::{Add, Mul, Neg};
+use little_learner::loss::{l2_loss_2, predict_line_2, square};
+use little_learner::traits::Zero;
+use ordered_float::NotNan;
 
 use crate::with_tensor::{l2_loss, predict_line};
 
-fn dot_2<A>(
-    x: &Differentiable<A>,
-    y: &Differentiable<A>,
-) -> Differentiable<A>
-where
-    A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default,
-{
-    Differentiable::map2(x, y, &|x, y| x.clone() * y.clone())
-}
-
-fn squared_2<A>(x: &Differentiable<A>) -> Differentiable<A>
-where
-    A: Mul<Output = A> + Copy + Default,
-{
-    Differentiable::map2(x, x, &|x, y| x.clone() * y.clone())
-}
-
-fn sum_2<A>(x: Differentiable<A>) -> Scalar<A>
-where
-    A: Sum + Copy + Add<Output = A> + Zero,
-{
-    Differentiable::to_vector(x)
-        .into_iter()
-        .map(to_scalar)
-        .sum()
-}
-
-fn l2_norm_2<A>(prediction: &Differentiable<A>, data: &Differentiable<A>) -> Scalar<A>
-where
-    A: Sum + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
-{
-    let diff = Differentiable::map2(prediction, data, &|x, y| x.clone() - y.clone());
-    sum_2(squared_2(&diff))
-}
-
-pub fn l2_loss_2<A, F, Params>(
-    target: F,
-    data_xs: Differentiable<A>,
-    data_ys: Differentiable<A>,
-    params: Params,
-) -> Scalar<A>
-where
-    F: Fn(Differentiable<A>, Params) -> Differentiable<A>,
-    A: Sum + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero,
-{
-    let pred_ys = target(data_xs, params);
-    l2_norm_2(&pred_ys, &data_ys)
-}
-
-fn predict_line_2<A>(xs: Differentiable<A>, theta: Differentiable<A>) -> Differentiable<A>
-where
-    A: Mul<Output = A> + Add<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + One + Zero,
-{
-    let xs = Differentiable::to_vector(xs)
-        .into_iter()
-        .map(|v| to_scalar(v));
-    let mut result = vec![];
-    for x in xs {
-        let left_arg = Differentiable::of_vector(vec![
-            of_scalar(x.clone()),
-            of_scalar(<Scalar<A> as One>::one()),
-        ]);
-        let dotted = Differentiable::to_vector(dot_2(&left_arg, &theta));
-        result.push(dotted[0].clone());
-    }
-    Differentiable::of_vector(result)
-}
-
-fn square<A>(x: &A) -> A
-where
-    A: Mul<Output = A> + Clone,
-{
-    x.clone() * x.clone()
+#[allow(dead_code)]
+fn l2_loss_non_autodiff_example() {
+    let xs = [2.0, 1.0, 4.0, 3.0];
+    let ys = [1.8, 1.2, 4.2, 3.3];
+    let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
+    println!("{:?}", loss);
 }
 
 fn main() {
-    let loss = l2_loss(
-        predict_line,
-        &[2.0, 1.0, 4.0, 3.0],
-        &[1.8, 1.2, 4.2, 3.3],
-        &[0.0099, 0.0],
-    );
-    println!("{:?}", loss);
+    let input_vec = of_slice(&[NotNan::new(27.0).expect("not nan")]);
+
+    let grad = Differentiable::grad(|x| Differentiable::map(x, &mut |x| square(&x)), input_vec);
+    println!("Gradient of the x^2 function at x=27: {}", grad);
+
+    let xs = [2.0, 1.0, 4.0, 3.0];
+    let ys = [1.8, 1.2, 4.2, 3.3];
 
     let loss = l2_loss_2(
         predict_line_2,
-        of_slice(&[2.0, 1.0, 4.0, 3.0]),
-        of_slice(&[1.8, 1.2, 4.2, 3.3]),
-        of_slice(&[0.0099, 0.0]),
+        of_slice(&xs),
+        of_slice(&ys),
+        of_slice(&[0.0, 0.0]),
     );
-    println!("{}", loss);
+    println!("Computation of L2 loss: {}", loss);
 
-    let input_vec = of_slice(&[NotNan::new(27.0).expect("not nan")]);
+    let input_vec = of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()]);
+    let xs = [2.0, 1.0, 4.0, 3.0].map(|x| NotNan::new(x).expect("not nan"));
+    let ys = [1.8, 1.2, 4.2, 3.3].map(|x| NotNan::new(x).expect("not nan"));
+    let grad = Differentiable::grad(
+        |x| {
+            Differentiable::of_vector(vec![of_scalar(l2_loss_2(
+                predict_line_2,
+                of_slice(&xs),
+                of_slice(&ys),
+                x,
+            ))])
+        },
+        input_vec,
+    );
 
-    let grad = Differentiable::grad(|x| Differentiable::map(x, &|x| square(&x)), input_vec);
     println!("{}", grad);
 }