From 0d2e5eb27753e40bc4c29e58bd47057b4059bc70 Mon Sep 17 00:00:00 2001 From: Patrick Stevens Date: Wed, 29 Mar 2023 21:00:13 +0100 Subject: [PATCH] Add rank parameters to autodiff (#6) --- .gitignore | 1 + flake.nix | 4 +- little_learner/rust-toolchain | 1 + little_learner/src/auto_diff.rs | 443 +++++++++++--------------- little_learner/src/lib.rs | 5 + little_learner/src/scalar.rs | 251 +++++++++++++++ little_learner/src/traits.rs | 43 +++ little_learner_app/src/main.rs | 172 ++++------ little_learner_app/src/with_tensor.rs | 126 ++++++++ rust-toolchain | 1 + 10 files changed, 685 insertions(+), 362 deletions(-) create mode 100644 little_learner/rust-toolchain create mode 100644 little_learner/src/scalar.rs create mode 100644 little_learner/src/traits.rs create mode 100644 little_learner_app/src/with_tensor.rs create mode 100644 rust-toolchain diff --git a/.gitignore b/.gitignore index d175129..ca319f7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ target/ .idea/ *.iml .vscode/ +.profile* diff --git a/flake.nix b/flake.nix index ccb3806..363782e 100644 --- a/flake.nix +++ b/flake.nix @@ -37,8 +37,8 @@ # Because rust-overlay bundles multiple rust packages into one # derivation, specify that mega-bundle here, so that crate2nix # will use them automatically. - rustc = self.rust-bin.stable.latest.default; - cargo = self.rust-bin.stable.latest.default; + rustc = self.rust-bin.nightly.latest.default; + cargo = self.rust-bin.nightly.latest.default; }) ]; }; diff --git a/little_learner/rust-toolchain b/little_learner/rust-toolchain new file mode 100644 index 0000000..bf867e0 --- /dev/null +++ b/little_learner/rust-toolchain @@ -0,0 +1 @@ +nightly diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs index ed4eac9..f305936 100644 --- a/little_learner/src/auto_diff.rs +++ b/little_learner/src/auto_diff.rs @@ -1,265 +1,64 @@ +use crate::scalar::Scalar; +use crate::traits::{Exp, One, Zero}; use core::hash::Hash; -use ordered_float::NotNan; +use std::collections::HashMap; use std::{ - collections::{hash_map::Entry, HashMap}, fmt::{Display, Write}, - ops::{Add, AddAssign, Div, Mul}, + ops::{AddAssign, Div, Mul, Neg}, }; -pub trait Zero { - fn zero() -> Self; -} - -pub trait One { - fn one() -> Self; -} - -impl Zero for f64 { - fn zero() -> Self { - 0.0 - } -} - -impl One for f64 { - fn one() -> Self { - 1.0 - } -} - -impl Zero for NotNan { - fn zero() -> Self { - NotNan::new(0.0).unwrap() - } -} - -impl One for NotNan { - fn one() -> Self { - NotNan::new(1.0).unwrap() - } -} - -impl Zero for Differentiable +impl Zero for DifferentiableHidden where A: Zero, { - fn zero() -> Differentiable { - Differentiable::Scalar(Scalar::Number(A::zero())) + fn zero() -> DifferentiableHidden { + DifferentiableHidden::Scalar(Scalar::Number(A::zero())) } } -impl One for Differentiable +impl One for Scalar where A: One, { - fn one() -> Differentiable { - Differentiable::Scalar(Scalar::Number(A::one())) + fn one() -> Scalar { + Scalar::Number(A::one()) } } -pub trait Exp { - fn exp(self) -> Self; -} - -impl Exp for NotNan { - fn exp(self) -> Self { - NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN") - } -} - -#[derive(Clone, Hash, PartialEq, Eq)] -pub enum LinkData { - Addition(Box>, Box>), - Mul(Box>, Box>), - Exponent(Box>), - Log(Box>), -} - -#[derive(Clone, Hash, PartialEq, Eq)] -pub enum Link { - EndOfLink, - Link(LinkData), -} - -impl Display for Link +impl One for DifferentiableHidden where - A: Display, + A: One, { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Link::EndOfLink => f.write_str(""), - Link::Link(LinkData::Addition(left, right)) => { - f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref())) - } - Link::Link(LinkData::Mul(left, right)) => { - f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref())) - } - Link::Link(LinkData::Exponent(arg)) => { - f.write_fmt(format_args!("exp({})", arg.as_ref())) - } - Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())), - } + fn one() -> DifferentiableHidden { + DifferentiableHidden::Scalar(Scalar::one()) } } -impl Link { - fn invoke(self, d: &Scalar, z: A, acc: &mut HashMap, A>) - where - A: Eq + Hash + AddAssign + Clone + Exp + Mul + Div + Zero + One, - { - match self { - Link::EndOfLink => match acc.entry(d.clone()) { - Entry::Occupied(mut o) => { - let entry = o.get_mut(); - *entry += z; - } - Entry::Vacant(v) => { - v.insert(z); - } - }, - Link::Link(data) => { - match data { - LinkData::Addition(left, right) => { - // The `z` here reflects the fact that dx/dx = 1, so it's 1 * z. - left.as_ref().clone_link().invoke(&left, z.clone(), acc); - right.as_ref().clone_link().invoke(&right, z, acc); - } - LinkData::Exponent(arg) => { - // d/dx (e^x) = exp x, so exp z * z. - arg.as_ref().clone_link().invoke( - &arg, - z * arg.clone_real_part().exp(), - acc, - ); - } - LinkData::Mul(left, right) => { - // d/dx(f g) = f dg/dx + g df/dx - left.as_ref().clone_link().invoke( - &left, - right.clone_real_part() * z.clone(), - acc, - ); - right - .as_ref() - .clone_link() - .invoke(&right, left.clone_real_part() * z, acc); - } - LinkData::Log(arg) => { - // d/dx(log y) = 1/y dy/dx - arg.as_ref().clone_link().invoke( - &arg, - A::one() / arg.clone_real_part() * z, - acc, - ); - } - } - } - } - } -} - -#[derive(Clone, Hash, PartialEq, Eq)] -pub enum Scalar { - Number(A), - // The value, and the link. - Dual(A, Link), -} - -impl Add for Scalar +impl Clone for DifferentiableHidden where - A: Add + Clone, + A: Clone, { - type Output = Scalar; - - fn add(self, rhs: Self) -> Self::Output { - Scalar::Dual( - self.clone_real_part() + rhs.clone_real_part(), - Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))), - ) - } -} - -impl Mul for Scalar -where - A: Mul + Clone, -{ - type Output = Scalar; - - fn mul(self, rhs: Self) -> Self::Output { - Scalar::Dual( - self.clone_real_part() * rhs.clone_real_part(), - Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))), - ) - } -} - -impl Scalar { - pub fn real_part(&self) -> &A { + fn clone(&self) -> Self { match self { - Scalar::Number(a) => a, - Scalar::Dual(a, _) => a, - } - } - - fn clone_real_part(&self) -> A - where - A: Clone, - { - match self { - Scalar::Number(a) => (*a).clone(), - Scalar::Dual(a, _) => (*a).clone(), - } - } - - pub fn link(self) -> Link { - match self { - Scalar::Dual(_, link) => link, - Scalar::Number(_) => Link::EndOfLink, - } - } - - fn clone_link(&self) -> Link - where - A: Clone, - { - match self { - Scalar::Dual(_, data) => data.clone(), - Scalar::Number(_) => Link::EndOfLink, - } - } - - fn truncate_dual(self) -> Scalar - where - A: Clone, - { - Scalar::Dual(self.clone_real_part(), Link::EndOfLink) - } -} - -impl Display for Scalar -where - A: Display, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Scalar::Number(n) => f.write_fmt(format_args!("{}", n)), - Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)), + Self::Scalar(arg0) => Self::Scalar(arg0.clone()), + Self::Vector(arg0) => Self::Vector(arg0.clone()), } } } -pub enum Differentiable { +enum DifferentiableHidden { Scalar(Scalar), - Vector(Box<[Differentiable]>), + Vector(Vec>), } -impl Display for Differentiable +impl Display for DifferentiableHidden where A: Display, { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)), - Differentiable::Vector(v) => { + DifferentiableHidden::Scalar(s) => f.write_fmt(format_args!("{}", s)), + DifferentiableHidden::Vector(v) => { f.write_char('[')?; for v in v.iter() { f.write_fmt(format_args!("{}", v))?; @@ -271,26 +70,70 @@ where } } -impl Differentiable { - pub fn map(&self, f: &F) -> Differentiable +impl DifferentiableHidden { + fn map(&self, f: &F) -> DifferentiableHidden where F: Fn(Scalar) -> Scalar, A: Clone, { match self { - Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())), - Differentiable::Vector(slice) => { - Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect()) + DifferentiableHidden::Scalar(a) => DifferentiableHidden::Scalar(f(a.clone())), + DifferentiableHidden::Vector(slice) => { + DifferentiableHidden::Vector(slice.iter().map(|x| x.map(f)).collect()) } } } + + fn map2(&self, other: &DifferentiableHidden, f: &F) -> DifferentiableHidden + where + F: Fn(&Scalar, &Scalar) -> Scalar, + A: Clone, + B: Clone, + { + match (self, other) { + (DifferentiableHidden::Scalar(a), DifferentiableHidden::Scalar(b)) => { + DifferentiableHidden::Scalar(f(a, b)) + } + (DifferentiableHidden::Vector(slice_a), DifferentiableHidden::Vector(slice_b)) => { + DifferentiableHidden::Vector( + slice_a + .iter() + .zip(slice_b.iter()) + .map(|(a, b)| a.map2(b, f)) + .collect(), + ) + } + _ => panic!("Wrong shapes!"), + } + } + + fn of_slice(input: &[A]) -> DifferentiableHidden + where + A: Clone, + { + DifferentiableHidden::Vector( + input + .iter() + .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone()))) + .collect(), + ) + } } -impl Differentiable +impl DifferentiableHidden where - A: Clone + Eq + Hash + AddAssign + Mul + Exp + Div + Zero + One, + A: Clone + + Eq + + Hash + + AddAssign + + Mul + + Exp + + Div + + Zero + + One + + Neg, { - fn accumulate_gradients_vec(v: &[Differentiable], acc: &mut HashMap, A>) { + fn accumulate_gradients_vec(v: &[DifferentiableHidden], acc: &mut HashMap, A>) { for v in v.iter().rev() { v.accumulate_gradients(acc); } @@ -298,15 +141,17 @@ where fn accumulate_gradients(&self, acc: &mut HashMap, A>) { match self { - Differentiable::Scalar(y) => { + DifferentiableHidden::Scalar(y) => { let k = y.clone_link(); k.invoke(y, A::one(), acc); } - Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc), + DifferentiableHidden::Vector(y) => { + DifferentiableHidden::accumulate_gradients_vec(y, acc) + } } } - fn grad_once(self, wrt: Differentiable) -> Differentiable { + fn grad_once(self, wrt: &DifferentiableHidden) -> DifferentiableHidden { let mut acc = HashMap::new(); self.accumulate_gradients(&mut acc); @@ -315,34 +160,133 @@ where Some(x) => Scalar::Number(x.clone()), }) } +} - pub fn grad(f: F, theta: Differentiable) -> Differentiable +#[derive(Clone)] +pub struct Differentiable { + contents: DifferentiableHidden, +} + +impl Display for Differentiable +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.contents, f) + } +} + +pub fn of_scalar(s: Scalar) -> Differentiable { + Differentiable { + contents: DifferentiableHidden::Scalar(s), + } +} + +pub fn to_scalar(s: Differentiable) -> Scalar { + match s.contents { + DifferentiableHidden::Scalar(s) => s, + DifferentiableHidden::Vector(_) => panic!("not a vector"), + } +} + +pub fn of_slice(input: &[A]) -> Differentiable +where + A: Clone, +{ + Differentiable { + contents: DifferentiableHidden::of_slice(input), + } +} + +impl Differentiable { + pub fn of_vector(s: Vec>) -> Differentiable { + Differentiable { + contents: DifferentiableHidden::Vector(s.into_iter().map(|v| v.contents).collect()), + } + } + + pub fn map(s: Differentiable, f: &F) -> Differentiable where - F: Fn(&Differentiable) -> Differentiable, + F: Fn(Scalar) -> Scalar, + A: Clone, { - let wrt = theta.map(&Scalar::truncate_dual); - let after_f = f(&wrt); - Differentiable::grad_once(after_f, wrt) + Differentiable { + contents: DifferentiableHidden::map(&s.contents, f), + } + } + + pub fn map2( + self: &Differentiable, + other: &Differentiable, + f: &F, + ) -> Differentiable + where + F: Fn(&Scalar, &Scalar) -> Scalar, + A: Clone, + B: Clone, + { + Differentiable { + contents: DifferentiableHidden::map2(&self.contents, &other.contents, f), + } + } + + pub fn to_vector(s: Differentiable) -> Vec> { + match s.contents { + DifferentiableHidden::Scalar(_) => panic!("not a scalar"), + DifferentiableHidden::Vector(v) => v + .into_iter() + .map(|v| Differentiable { contents: v }) + .collect(), + } + } + + pub fn grad(f: F, theta: Differentiable) -> Differentiable + where + F: Fn(Differentiable) -> Differentiable, + A: Clone + + Hash + + AddAssign + + Mul + + Exp + + Div + + Zero + + One + + Neg + + Eq, + { + let wrt = theta.contents.map(&Scalar::truncate_dual); + let after_f = f(Differentiable { + contents: wrt.clone(), + }); + Differentiable { + contents: DifferentiableHidden::grad_once(after_f.contents, &wrt), + } } } #[cfg(test)] mod tests { + use ordered_float::NotNan; + use super::*; - fn extract_scalar<'a, A>(d: &'a Differentiable) -> &'a A { + fn extract_scalar<'a, A>(d: &'a DifferentiableHidden) -> &'a A { match d { - Differentiable::Scalar(a) => &(a.real_part()), - Differentiable::Vector(_) => panic!("not a scalar"), + DifferentiableHidden::Scalar(a) => &(a.real_part()), + DifferentiableHidden::Vector(_) => panic!("not a scalar"), } } #[test] fn test_map() { - let v = Differentiable::Vector( + let v = DifferentiableHidden::Vector( vec![ - Differentiable::Scalar(Scalar::Number(NotNan::new(3.0).expect("3 is not NaN"))), - Differentiable::Scalar(Scalar::Number(NotNan::new(4.0).expect("4 is not NaN"))), + DifferentiableHidden::Scalar(Scalar::Number( + NotNan::new(3.0).expect("3 is not NaN"), + )), + DifferentiableHidden::Scalar(Scalar::Number( + NotNan::new(4.0).expect("4 is not NaN"), + )), ] .into(), ); @@ -352,9 +296,8 @@ mod tests { }); let v = match mapped { - Differentiable::Scalar(_) => panic!("Not a scalar"), - Differentiable::Vector(v) => v - .as_ref() + DifferentiableHidden::Scalar(_) => panic!("Not a scalar"), + DifferentiableHidden::Vector(v) => v .iter() .map(|d| extract_scalar(d).clone()) .collect::>(), diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs index 7590a38..49511b1 100644 --- a/little_learner/src/lib.rs +++ b/little_learner/src/lib.rs @@ -1,3 +1,8 @@ +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] + pub mod auto_diff; pub mod expr_syntax_tree; +pub mod scalar; pub mod tensor; +pub mod traits; diff --git a/little_learner/src/scalar.rs b/little_learner/src/scalar.rs new file mode 100644 index 0000000..01b9014 --- /dev/null +++ b/little_learner/src/scalar.rs @@ -0,0 +1,251 @@ +use crate::traits::{Exp, One, Zero}; +use core::hash::Hash; +use std::{ + collections::{hash_map::Entry, HashMap}, + fmt::Display, + iter::Sum, + ops::{Add, AddAssign, Div, Mul, Neg, Sub}, +}; + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum LinkData { + Addition(Box>, Box>), + Neg(Box>), + Mul(Box>, Box>), + Exponent(Box>), + Log(Box>), +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum Link { + EndOfLink, + Link(LinkData), +} + +impl Display for Link +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Link::EndOfLink => f.write_str(""), + Link::Link(LinkData::Addition(left, right)) => { + f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref())) + } + Link::Link(LinkData::Neg(arg)) => f.write_fmt(format_args!("(-{})", arg.as_ref())), + Link::Link(LinkData::Mul(left, right)) => { + f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref())) + } + Link::Link(LinkData::Exponent(arg)) => { + f.write_fmt(format_args!("exp({})", arg.as_ref())) + } + Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())), + } + } +} + +impl Link { + pub fn invoke(self, d: &Scalar, z: A, acc: &mut HashMap, A>) + where + A: Eq + + Hash + + AddAssign + + Clone + + Exp + + Mul + + Div + + Neg + + Zero + + One, + { + match self { + Link::EndOfLink => match acc.entry(d.clone()) { + Entry::Occupied(mut o) => { + let entry = o.get_mut(); + *entry += z; + } + Entry::Vacant(v) => { + v.insert(z); + } + }, + Link::Link(data) => { + match data { + LinkData::Addition(left, right) => { + // The `z` here reflects the fact that dx/dx = 1, so it's 1 * z. + left.as_ref().clone_link().invoke(&left, z.clone(), acc); + right.as_ref().clone_link().invoke(&right, z, acc); + } + LinkData::Exponent(arg) => { + // d/dx (e^x) = exp x, so exp z * z. + arg.as_ref().clone_link().invoke( + &arg, + z * arg.clone_real_part().exp(), + acc, + ); + } + LinkData::Mul(left, right) => { + // d/dx(f g) = f dg/dx + g df/dx + left.as_ref().clone_link().invoke( + &left, + right.clone_real_part() * z.clone(), + acc, + ); + right + .as_ref() + .clone_link() + .invoke(&right, left.clone_real_part() * z, acc); + } + LinkData::Log(arg) => { + // d/dx(log y) = 1/y dy/dx + arg.as_ref().clone_link().invoke( + &arg, + A::one() / arg.clone_real_part() * z, + acc, + ); + } + LinkData::Neg(arg) => { + // d/dx(-y) = - dy/dx + arg.as_ref().clone_link().invoke(&arg, -z, acc); + } + } + } + } + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum Scalar { + Number(A), + // The value, and the link. + Dual(A, Link), +} + +impl Zero for Scalar +where + A: Zero, +{ + fn zero() -> Self { + Scalar::Number(A::zero()) + } +} + +impl Add for Scalar +where + A: Add + Clone, +{ + type Output = Scalar; + + fn add(self, rhs: Self) -> Self::Output { + Scalar::Dual( + self.clone_real_part() + rhs.clone_real_part(), + Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))), + ) + } +} + +impl Neg for Scalar +where + A: Neg + Clone, +{ + type Output = Scalar; + + fn neg(self) -> Self::Output { + Scalar::Dual( + -self.clone_real_part(), + Link::Link(LinkData::Neg(Box::new(self))), + ) + } +} + +impl Sub for Scalar +where + A: Add + Neg + Clone, +{ + type Output = Scalar; + + fn sub(self, rhs: Self) -> Self::Output { + self + (-rhs) + } +} + +impl Mul for Scalar +where + A: Mul + Clone, +{ + type Output = Scalar; + + fn mul(self, rhs: Self) -> Self::Output { + Scalar::Dual( + self.clone_real_part() * rhs.clone_real_part(), + Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))), + ) + } +} + +impl Sum for Scalar +where + A: Zero + Add + Clone, +{ + fn sum>(iter: I) -> Self { + let mut answer = Zero::zero(); + for i in iter { + answer = answer + i; + } + answer + } +} + +impl Scalar { + pub fn real_part(&self) -> &A { + match self { + Scalar::Number(a) => a, + Scalar::Dual(a, _) => a, + } + } + + pub fn clone_real_part(&self) -> A + where + A: Clone, + { + match self { + Scalar::Number(a) => (*a).clone(), + Scalar::Dual(a, _) => (*a).clone(), + } + } + + pub fn link(self) -> Link { + match self { + Scalar::Dual(_, link) => link, + Scalar::Number(_) => Link::EndOfLink, + } + } + + pub fn clone_link(&self) -> Link + where + A: Clone, + { + match self { + Scalar::Dual(_, data) => data.clone(), + Scalar::Number(_) => Link::EndOfLink, + } + } + + pub fn truncate_dual(self) -> Scalar + where + A: Clone, + { + Scalar::Dual(self.clone_real_part(), Link::EndOfLink) + } +} + +impl Display for Scalar +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Scalar::Number(n) => f.write_fmt(format_args!("{}", n)), + Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)), + } + } +} diff --git a/little_learner/src/traits.rs b/little_learner/src/traits.rs new file mode 100644 index 0000000..c87401e --- /dev/null +++ b/little_learner/src/traits.rs @@ -0,0 +1,43 @@ +use ordered_float::NotNan; + +pub trait Exp { + fn exp(self) -> Self; +} + +impl Exp for NotNan { + fn exp(self) -> Self { + NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN") + } +} + +pub trait Zero { + fn zero() -> Self; +} + +pub trait One { + fn one() -> Self; +} + +impl Zero for f64 { + fn zero() -> Self { + 0.0 + } +} + +impl One for f64 { + fn one() -> Self { + 1.0 + } +} + +impl Zero for NotNan { + fn zero() -> Self { + NotNan::new(0.0).unwrap() + } +} + +impl One for NotNan { + fn one() -> Self { + NotNan::new(1.0).unwrap() + } +} diff --git a/little_learner_app/src/main.rs b/little_learner_app/src/main.rs index 5fcba56..5c6aad7 100644 --- a/little_learner_app/src/main.rs +++ b/little_learner_app/src/main.rs @@ -1,93 +1,90 @@ -use little_learner::auto_diff::{Differentiable, Scalar}; -use little_learner::tensor; -use little_learner::tensor::{extension2, Extensible2}; +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] + +mod with_tensor; + +use little_learner::auto_diff::{of_scalar, of_slice, to_scalar, Differentiable}; +use little_learner::scalar::Scalar; +use little_learner::traits::{One, Zero}; use ordered_float::NotNan; use std::iter::Sum; -use std::ops::{Mul, Sub}; +use std::ops::{Add, Mul, Neg}; -type Point = [A; N]; +use crate::with_tensor::{l2_loss, predict_line}; -type Parameters = [Point; M]; - -fn dot_points(x: &Point, y: &Point) -> A +fn dot_2( + x: &Differentiable, + y: &Differentiable, +) -> Differentiable where - A: Sum<::Output> + Copy + Default + Mul + Extensible2, + A: Mul + Sum<::Output> + Copy + Default, { - extension2(x, y, |&x, &y| x * y).into_iter().sum() + Differentiable::map2(x, y, &|x, y| x.clone() * y.clone()) } -fn dot(x: &Point, y: &Parameters) -> Point +fn squared_2(x: &Differentiable) -> Differentiable where - A: Mul + Sum<::Output> + Copy + Default + Extensible2, + A: Mul + Copy + Default, { - let mut result = [Default::default(); M]; - for (i, coord) in y.iter().map(|y| dot_points(x, y)).enumerate() { - result[i] = coord; - } - result + Differentiable::map2(x, x, &|x, y| x.clone() * y.clone()) } -fn sum(x: &tensor!(A, N)) -> A +fn sum_2(x: Differentiable) -> Scalar where - A: Sum + Copy, + A: Sum + Copy + Add + Zero, { - A::sum(x.iter().cloned()) + Differentiable::to_vector(x) + .into_iter() + .map(to_scalar) + .sum() } -fn squared(x: &tensor!(A, N)) -> tensor!(A, N) +fn l2_norm_2(prediction: &Differentiable, data: &Differentiable) -> Scalar where - A: Mul + Extensible2 + Copy + Default, + A: Sum + Mul + Copy + Default + Neg + Add + Zero + Neg, { - extension2(x, x, |&a, &b| (a * b)) + let diff = Differentiable::map2(prediction, data, &|x, y| x.clone() - y.clone()); + sum_2(squared_2(&diff)) } -fn l2_norm(prediction: &tensor!(A, N), data: &tensor!(A, N)) -> A -where - A: Sum + Mul + Extensible2 + Copy + Default + Sub, -{ - let diff = extension2(prediction, data, |&x, &y| x - y); - sum(&squared(&diff)) -} - -pub fn l2_loss( +pub fn l2_loss_2( target: F, - data_xs: &tensor!(A, N), - data_ys: &tensor!(A, N), - params: &Params, -) -> A + data_xs: Differentiable, + data_ys: Differentiable, + params: Params, +) -> Scalar where - F: Fn(&tensor!(A, N), &Params) -> tensor!(A, N), - A: Sum + Mul + Extensible2 + Copy + Default + Sub, + F: Fn(Differentiable, Params) -> Differentiable, + A: Sum + Mul + Copy + Default + Neg + Add + Zero, { let pred_ys = target(data_xs, params); - l2_norm(&pred_ys, data_ys) + l2_norm_2(&pred_ys, &data_ys) } -trait One { - const ONE: Self; -} - -impl One for f64 { - const ONE: f64 = 1.0; -} - -fn predict_line(xs: &tensor!(A, N), theta: &tensor!(A, 2)) -> tensor!(A, N) +fn predict_line_2(xs: Differentiable, theta: Differentiable) -> Differentiable where - A: Mul + Sum<::Output> + Copy + Default + Extensible2 + One, + A: Mul + Add + Sum<::Output> + Copy + Default + One + Zero, { - let mut result: tensor!(A, N) = [Default::default(); N]; - for (i, &x) in xs.iter().enumerate() { - result[i] = dot(&[x, One::ONE], &[*theta])[0]; + let xs = Differentiable::to_vector(xs) + .into_iter() + .map(|v| to_scalar(v)); + let mut result = vec![]; + for x in xs { + let left_arg = Differentiable::of_vector(vec![ + of_scalar(x.clone()), + of_scalar( as One>::one()), + ]); + let dotted = Differentiable::to_vector(dot_2(&left_arg, &theta)); + result.push(dotted[0].clone()); } - result + Differentiable::of_vector(result) } fn square(x: &A) -> A where - A: Mul + Clone + std::fmt::Display, + A: Mul + Clone, { - println!("{}", x); x.clone() * x.clone() } @@ -100,61 +97,16 @@ fn main() { ); println!("{:?}", loss); - let input_vec = Differentiable::Vector(Box::new([Differentiable::Scalar(Scalar::Number( - NotNan::new(27.0).expect("not nan"), - ))])); + let loss = l2_loss_2( + predict_line_2, + of_slice(&[2.0, 1.0, 4.0, 3.0]), + of_slice(&[1.8, 1.2, 4.2, 3.3]), + of_slice(&[0.0099, 0.0]), + ); + println!("{}", loss); - let grad = Differentiable::grad(|x| x.map(&|x| square(&x)), input_vec); + let input_vec = of_slice(&[NotNan::new(27.0).expect("not nan")]); + + let grad = Differentiable::grad(|x| Differentiable::map(x, &|x| square(&x)), input_vec); println!("{}", grad); } - -#[cfg(test)] -mod tests { - use super::*; - use little_learner::tensor::extension1; - - #[test] - fn test_extension() { - let x: tensor!(u8, 1) = [2]; - assert_eq!(extension1(&x, &7, |x, y| x + y), [9]); - let y: tensor!(u8, 1) = [7]; - assert_eq!(extension2(&x, &y, |x, y| x + y), [9]); - - let x: tensor!(u8, 3) = [5, 6, 7]; - assert_eq!(extension1(&x, &2, |x, y| x + y), [7, 8, 9]); - let y: tensor!(u8, 3) = [2, 0, 1]; - assert_eq!(extension2(&x, &y, |x, y| x + y), [7, 6, 8]); - - let x: tensor!(u8, 2, 3) = [[4, 6, 7], [2, 0, 1]]; - assert_eq!(extension1(&x, &2, |x, y| x + y), [[6, 8, 9], [4, 2, 3]]); - let y: tensor!(u8, 2, 3) = [[1, 2, 2], [6, 3, 1]]; - assert_eq!(extension2(&x, &y, |x, y| x + y), [[5, 8, 9], [8, 3, 2]]); - } - - #[test] - fn test_l2_norm() { - assert_eq!( - l2_norm(&[4.0, -3.0, 0.0, -4.0, 3.0], &[0.0, 0.0, 0.0, 0.0, 0.0]), - 50.0 - ) - } - - #[test] - fn test_l2_loss() { - let loss = l2_loss( - predict_line, - &[2.0, 1.0, 4.0, 3.0], - &[1.8, 1.2, 4.2, 3.3], - &[0.0, 0.0], - ); - assert_eq!(loss, 33.21); - - let loss = l2_loss( - predict_line, - &[2.0, 1.0, 4.0, 3.0], - &[1.8, 1.2, 4.2, 3.3], - &[0.0099, 0.0], - ); - assert_eq!((100.0 * loss).round() / 100.0, 32.59); - } -} diff --git a/little_learner_app/src/with_tensor.rs b/little_learner_app/src/with_tensor.rs new file mode 100644 index 0000000..9996733 --- /dev/null +++ b/little_learner_app/src/with_tensor.rs @@ -0,0 +1,126 @@ +use std::iter::Sum; +use std::ops::{Mul, Sub}; + +use little_learner::tensor; +use little_learner::tensor::{extension2, Extensible2}; +use little_learner::traits::One; + +type Point = [A; N]; + +type Parameters = [Point; M]; + +fn dot_points(x: &Point, y: &Point) -> A +where + A: Sum<::Output> + Copy + Default + Mul + Extensible2, +{ + extension2(x, y, |&x, &y| x * y).into_iter().sum() +} + +fn dot(x: &Point, y: &Parameters) -> Point +where + A: Mul + Sum<::Output> + Copy + Default + Extensible2, +{ + let mut result = [Default::default(); M]; + for (i, coord) in y.iter().map(|y| dot_points(x, y)).enumerate() { + result[i] = coord; + } + result +} + +fn sum(x: &tensor!(A, N)) -> A +where + A: Sum + Copy, +{ + A::sum(x.iter().cloned()) +} + +fn squared(x: &tensor!(A, N)) -> tensor!(A, N) +where + A: Mul + Extensible2 + Copy + Default, +{ + extension2(x, x, |&a, &b| (a * b)) +} + +fn l2_norm(prediction: &tensor!(A, N), data: &tensor!(A, N)) -> A +where + A: Sum + Mul + Extensible2 + Copy + Default + Sub, +{ + let diff = extension2(prediction, data, |&x, &y| x - y); + sum(&squared(&diff)) +} + +pub fn l2_loss( + target: F, + data_xs: &tensor!(A, N), + data_ys: &tensor!(A, N), + params: &Params, +) -> A +where + F: Fn(&tensor!(A, N), &Params) -> tensor!(A, N), + A: Sum + Mul + Extensible2 + Copy + Default + Sub, +{ + let pred_ys = target(data_xs, params); + l2_norm(&pred_ys, data_ys) +} + +pub fn predict_line(xs: &tensor!(A, N), theta: &tensor!(A, 2)) -> tensor!(A, N) +where + A: Mul + Sum<::Output> + Copy + Default + Extensible2 + One, +{ + let mut result: tensor!(A, N) = [Default::default(); N]; + for (i, &x) in xs.iter().enumerate() { + result[i] = dot(&[x, One::one()], &[*theta])[0]; + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + use little_learner::tensor::extension1; + + #[test] + fn test_extension() { + let x: tensor!(u8, 1) = [2]; + assert_eq!(extension1(&x, &7, |x, y| x + y), [9]); + let y: tensor!(u8, 1) = [7]; + assert_eq!(extension2(&x, &y, |x, y| x + y), [9]); + + let x: tensor!(u8, 3) = [5, 6, 7]; + assert_eq!(extension1(&x, &2, |x, y| x + y), [7, 8, 9]); + let y: tensor!(u8, 3) = [2, 0, 1]; + assert_eq!(extension2(&x, &y, |x, y| x + y), [7, 6, 8]); + + let x: tensor!(u8, 2, 3) = [[4, 6, 7], [2, 0, 1]]; + assert_eq!(extension1(&x, &2, |x, y| x + y), [[6, 8, 9], [4, 2, 3]]); + let y: tensor!(u8, 2, 3) = [[1, 2, 2], [6, 3, 1]]; + assert_eq!(extension2(&x, &y, |x, y| x + y), [[5, 8, 9], [8, 3, 2]]); + } + + #[test] + fn test_l2_norm() { + assert_eq!( + l2_norm(&[4.0, -3.0, 0.0, -4.0, 3.0], &[0.0, 0.0, 0.0, 0.0, 0.0]), + 50.0 + ) + } + + #[test] + fn test_l2_loss() { + let loss = l2_loss( + predict_line, + &[2.0, 1.0, 4.0, 3.0], + &[1.8, 1.2, 4.2, 3.3], + &[0.0, 0.0], + ); + assert_eq!(loss, 33.21); + + let loss = l2_loss( + predict_line, + &[2.0, 1.0, 4.0, 3.0], + &[1.8, 1.2, 4.2, 3.3], + &[0.0099, 0.0], + ); + assert_eq!((100.0 * loss).round() / 100.0, 32.59); + } +} diff --git a/rust-toolchain b/rust-toolchain new file mode 100644 index 0000000..bf867e0 --- /dev/null +++ b/rust-toolchain @@ -0,0 +1 @@ +nightly