From 32caf8d7d6f968c8d69d60506d97da13c3373829 Mon Sep 17 00:00:00 2001 From: Patrick Stevens Date: Sat, 25 Mar 2023 22:19:04 +0000 Subject: [PATCH] Automatic differentiation (#5) --- .gitignore | 3 +- Cargo.lock | 34 ++ Cargo.toml | 14 +- flake.lock | 12 +- flake.nix | 4 +- little_learner/Cargo.lock | 147 +++++++ little_learner/Cargo.toml | 12 + little_learner/src/auto_diff.rs | 365 ++++++++++++++++++ .../src}/expr_syntax_tree.rs | 50 ++- little_learner/src/lib.rs | 3 + little_learner/src/tensor.rs | 107 +++++ little_learner_app/Cargo.lock | 156 ++++++++ little_learner_app/Cargo.toml | 11 + {src => little_learner_app/src}/main.rs | 126 +----- 14 files changed, 894 insertions(+), 150 deletions(-) create mode 100644 little_learner/Cargo.lock create mode 100644 little_learner/Cargo.toml create mode 100644 little_learner/src/auto_diff.rs rename {src => little_learner/src}/expr_syntax_tree.rs (84%) create mode 100644 little_learner/src/lib.rs create mode 100644 little_learner/src/tensor.rs create mode 100644 little_learner_app/Cargo.lock create mode 100644 little_learner_app/Cargo.toml rename {src => little_learner_app/src}/main.rs (62%) diff --git a/.gitignore b/.gitignore index 7b65c95..d175129 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ -/target +target/ .idea/ *.iml +.vscode/ diff --git a/Cargo.lock b/Cargo.lock index bb58012..c30cc64 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,12 @@ version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + [[package]] name = "bitvec" version = "1.0.1" @@ -42,6 +48,34 @@ name = "little_learner" version = "0.1.0" dependencies = [ "immutable-chunkmap", + "ordered-float", +] + +[[package]] +name = "little_learner_app" +version = "0.1.0" +dependencies = [ + "immutable-chunkmap", + "little_learner", + "ordered-float", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ordered-float" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13a384337e997e6860ffbaa83708b2ef329fd8c54cb67a5f64d421e0f943254f" +dependencies = [ + "num-traits", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a2558bf..4b703c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,9 +1,5 @@ -[package] -name = "little_learner" -version = "0.1.0" -edition = "2021" - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -immutable-chunkmap = "1.0.5" +[workspace] +members = [ + "little_learner", + "little_learner_app" +] diff --git a/flake.lock b/flake.lock index 097af12..f3f7ab2 100644 --- a/flake.lock +++ b/flake.lock @@ -49,11 +49,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1678898370, - "narHash": "sha256-xTICr1j+uat5hk9FyuPOFGxpWHdJRibwZC+ATi0RbtE=", + "lastModified": 1679705136, + "narHash": "sha256-MDlZUR7wJ3PlPtqwwoGQr3euNOe0vdSSteVVOef7tBY=", "owner": "nixos", "repo": "nixpkgs", - "rev": "ac718d02867a84b42522a0ece52d841188208f2c", + "rev": "8f40f2f90b9c9032d1b824442cfbbe0dbabd0dbd", "type": "github" }, "original": { @@ -65,11 +65,11 @@ }, "nixpkgs_2": { "locked": { - "lastModified": 1665296151, - "narHash": "sha256-uOB0oxqxN9K7XGF1hcnY+PQnlQJ+3bP2vCn/+Ru/bbc=", + "lastModified": 1679734080, + "narHash": "sha256-z846xfGLlon6t9lqUzlNtBOmsgQLQIZvR6Lt2dImk1M=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "14ccaaedd95a488dd7ae142757884d8e125b3363", + "rev": "dbf5322e93bcc6cfc52268367a8ad21c09d76fea", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 90c290d..ccb3806 100644 --- a/flake.nix +++ b/flake.nix @@ -23,7 +23,7 @@ crate2nix, ... }: let - name = "little_learner"; + name = "little_learner_app"; in utils.lib.eachDefaultSystem ( @@ -77,7 +77,7 @@ PKG_CONFIG_PATH = "${pkgs.openssl.dev}/lib/pkgconfig"; }; in rec { - packages.${name} = project.rootCrate.build; + packages.${name} = project.workspaceMembers.${name}.build; # `nix build` defaultPackage = packages.${name}; diff --git a/little_learner/Cargo.lock b/little_learner/Cargo.lock new file mode 100644 index 0000000..4b27677 --- /dev/null +++ b/little_learner/Cargo.lock @@ -0,0 +1,147 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "arrayvec" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "immutable-chunkmap" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c" +dependencies = [ + "arrayvec", + "packed_struct", + "packed_struct_codegen", +] + +[[package]] +name = "little_learner" +version = "0.1.0" +dependencies = [ + "immutable-chunkmap", + "ordered-float", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ordered-float" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13a384337e997e6860ffbaa83708b2ef329fd8c54cb67a5f64d421e0f943254f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "packed_struct" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190" +dependencies = [ + "bitvec", + "packed_struct_codegen", +] + +[[package]] +name = "packed_struct_codegen" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "unicode-ident" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] diff --git a/little_learner/Cargo.toml b/little_learner/Cargo.toml new file mode 100644 index 0000000..b32090b --- /dev/null +++ b/little_learner/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "little_learner" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +immutable-chunkmap = "1.0.5" +ordered-float = "3.6.0" + +[lib] diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs new file mode 100644 index 0000000..ed4eac9 --- /dev/null +++ b/little_learner/src/auto_diff.rs @@ -0,0 +1,365 @@ +use core::hash::Hash; +use ordered_float::NotNan; +use std::{ + collections::{hash_map::Entry, HashMap}, + fmt::{Display, Write}, + ops::{Add, AddAssign, Div, Mul}, +}; + +pub trait Zero { + fn zero() -> Self; +} + +pub trait One { + fn one() -> Self; +} + +impl Zero for f64 { + fn zero() -> Self { + 0.0 + } +} + +impl One for f64 { + fn one() -> Self { + 1.0 + } +} + +impl Zero for NotNan { + fn zero() -> Self { + NotNan::new(0.0).unwrap() + } +} + +impl One for NotNan { + fn one() -> Self { + NotNan::new(1.0).unwrap() + } +} + +impl Zero for Differentiable +where + A: Zero, +{ + fn zero() -> Differentiable { + Differentiable::Scalar(Scalar::Number(A::zero())) + } +} + +impl One for Differentiable +where + A: One, +{ + fn one() -> Differentiable { + Differentiable::Scalar(Scalar::Number(A::one())) + } +} + +pub trait Exp { + fn exp(self) -> Self; +} + +impl Exp for NotNan { + fn exp(self) -> Self { + NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN") + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum LinkData { + Addition(Box>, Box>), + Mul(Box>, Box>), + Exponent(Box>), + Log(Box>), +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum Link { + EndOfLink, + Link(LinkData), +} + +impl Display for Link +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Link::EndOfLink => f.write_str(""), + Link::Link(LinkData::Addition(left, right)) => { + f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref())) + } + Link::Link(LinkData::Mul(left, right)) => { + f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref())) + } + Link::Link(LinkData::Exponent(arg)) => { + f.write_fmt(format_args!("exp({})", arg.as_ref())) + } + Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())), + } + } +} + +impl Link { + fn invoke(self, d: &Scalar, z: A, acc: &mut HashMap, A>) + where + A: Eq + Hash + AddAssign + Clone + Exp + Mul + Div + Zero + One, + { + match self { + Link::EndOfLink => match acc.entry(d.clone()) { + Entry::Occupied(mut o) => { + let entry = o.get_mut(); + *entry += z; + } + Entry::Vacant(v) => { + v.insert(z); + } + }, + Link::Link(data) => { + match data { + LinkData::Addition(left, right) => { + // The `z` here reflects the fact that dx/dx = 1, so it's 1 * z. + left.as_ref().clone_link().invoke(&left, z.clone(), acc); + right.as_ref().clone_link().invoke(&right, z, acc); + } + LinkData::Exponent(arg) => { + // d/dx (e^x) = exp x, so exp z * z. + arg.as_ref().clone_link().invoke( + &arg, + z * arg.clone_real_part().exp(), + acc, + ); + } + LinkData::Mul(left, right) => { + // d/dx(f g) = f dg/dx + g df/dx + left.as_ref().clone_link().invoke( + &left, + right.clone_real_part() * z.clone(), + acc, + ); + right + .as_ref() + .clone_link() + .invoke(&right, left.clone_real_part() * z, acc); + } + LinkData::Log(arg) => { + // d/dx(log y) = 1/y dy/dx + arg.as_ref().clone_link().invoke( + &arg, + A::one() / arg.clone_real_part() * z, + acc, + ); + } + } + } + } + } +} + +#[derive(Clone, Hash, PartialEq, Eq)] +pub enum Scalar { + Number(A), + // The value, and the link. + Dual(A, Link), +} + +impl Add for Scalar +where + A: Add + Clone, +{ + type Output = Scalar; + + fn add(self, rhs: Self) -> Self::Output { + Scalar::Dual( + self.clone_real_part() + rhs.clone_real_part(), + Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))), + ) + } +} + +impl Mul for Scalar +where + A: Mul + Clone, +{ + type Output = Scalar; + + fn mul(self, rhs: Self) -> Self::Output { + Scalar::Dual( + self.clone_real_part() * rhs.clone_real_part(), + Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))), + ) + } +} + +impl Scalar { + pub fn real_part(&self) -> &A { + match self { + Scalar::Number(a) => a, + Scalar::Dual(a, _) => a, + } + } + + fn clone_real_part(&self) -> A + where + A: Clone, + { + match self { + Scalar::Number(a) => (*a).clone(), + Scalar::Dual(a, _) => (*a).clone(), + } + } + + pub fn link(self) -> Link { + match self { + Scalar::Dual(_, link) => link, + Scalar::Number(_) => Link::EndOfLink, + } + } + + fn clone_link(&self) -> Link + where + A: Clone, + { + match self { + Scalar::Dual(_, data) => data.clone(), + Scalar::Number(_) => Link::EndOfLink, + } + } + + fn truncate_dual(self) -> Scalar + where + A: Clone, + { + Scalar::Dual(self.clone_real_part(), Link::EndOfLink) + } +} + +impl Display for Scalar +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Scalar::Number(n) => f.write_fmt(format_args!("{}", n)), + Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)), + } + } +} + +pub enum Differentiable { + Scalar(Scalar), + Vector(Box<[Differentiable]>), +} + +impl Display for Differentiable +where + A: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)), + Differentiable::Vector(v) => { + f.write_char('[')?; + for v in v.iter() { + f.write_fmt(format_args!("{}", v))?; + f.write_char(',')?; + } + f.write_char(']') + } + } + } +} + +impl Differentiable { + pub fn map(&self, f: &F) -> Differentiable + where + F: Fn(Scalar) -> Scalar, + A: Clone, + { + match self { + Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())), + Differentiable::Vector(slice) => { + Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect()) + } + } + } +} + +impl Differentiable +where + A: Clone + Eq + Hash + AddAssign + Mul + Exp + Div + Zero + One, +{ + fn accumulate_gradients_vec(v: &[Differentiable], acc: &mut HashMap, A>) { + for v in v.iter().rev() { + v.accumulate_gradients(acc); + } + } + + fn accumulate_gradients(&self, acc: &mut HashMap, A>) { + match self { + Differentiable::Scalar(y) => { + let k = y.clone_link(); + k.invoke(y, A::one(), acc); + } + Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc), + } + } + + fn grad_once(self, wrt: Differentiable) -> Differentiable { + let mut acc = HashMap::new(); + self.accumulate_gradients(&mut acc); + + wrt.map(&|d| match acc.get(&d) { + None => Scalar::Number(A::zero()), + Some(x) => Scalar::Number(x.clone()), + }) + } + + pub fn grad(f: F, theta: Differentiable) -> Differentiable + where + F: Fn(&Differentiable) -> Differentiable, + { + let wrt = theta.map(&Scalar::truncate_dual); + let after_f = f(&wrt); + Differentiable::grad_once(after_f, wrt) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn extract_scalar<'a, A>(d: &'a Differentiable) -> &'a A { + match d { + Differentiable::Scalar(a) => &(a.real_part()), + Differentiable::Vector(_) => panic!("not a scalar"), + } + } + + #[test] + fn test_map() { + let v = Differentiable::Vector( + vec![ + Differentiable::Scalar(Scalar::Number(NotNan::new(3.0).expect("3 is not NaN"))), + Differentiable::Scalar(Scalar::Number(NotNan::new(4.0).expect("4 is not NaN"))), + ] + .into(), + ); + let mapped = v.map(&|x: Scalar>| match x { + Scalar::Number(i) => Scalar::Number(i + NotNan::new(1.0).expect("1 is not NaN")), + Scalar::Dual(_, _) => panic!("Not hit"), + }); + + let v = match mapped { + Differentiable::Scalar(_) => panic!("Not a scalar"), + Differentiable::Vector(v) => v + .as_ref() + .iter() + .map(|d| extract_scalar(d).clone()) + .collect::>(), + }; + + assert_eq!(v, [4.0, 5.0]); + } +} diff --git a/src/expr_syntax_tree.rs b/little_learner/src/expr_syntax_tree.rs similarity index 84% rename from src/expr_syntax_tree.rs rename to little_learner/src/expr_syntax_tree.rs index 7c9d2b1..435af58 100644 --- a/src/expr_syntax_tree.rs +++ b/little_learner/src/expr_syntax_tree.rs @@ -1,5 +1,3 @@ -#![allow(dead_code)] - use immutable_chunkmap::map; use std::ops::{Add, Mul}; @@ -50,24 +48,15 @@ impl Expr { Expr::Apply(var, Box::new(f), Box::new(arg)) } - pub fn sum(x: Expr, y: Expr) -> Expr { - Expr::Sum(Box::new(x), Box::new(y)) - } - - pub fn mul(x: Expr, y: Expr) -> Expr { - Expr::Mul(Box::new(x), Box::new(y)) - } - pub fn differentiate(one: &A, zero: &A, var: u32, f: &Expr) -> Expr where A: Clone, { match f { Expr::Const(_) => Expr::Const(zero.clone()), - Expr::Sum(x, y) => Expr::sum( - Expr::differentiate(one, zero, var, x), - Expr::differentiate(one, zero, var, y), - ), + Expr::Sum(x, y) => { + Expr::differentiate(one, zero, var, x) + Expr::differentiate(one, zero, var, y) + } Expr::Variable(i) => { if *i == var { Expr::Const(one.clone()) @@ -75,16 +64,15 @@ impl Expr { Expr::Const(zero.clone()) } } - Expr::Mul(x, y) => Expr::sum( + Expr::Mul(x, y) => { Expr::Mul( Box::new(Expr::differentiate(one, zero, var, x.as_ref())), (*y).clone(), - ), - Expr::Mul( + ) + Expr::Mul( Box::new(Expr::differentiate(one, zero, var, y.as_ref())), (*x).clone(), - ), - ), + ) + } Expr::Apply(new_var, func, expr) => { if *new_var == var { panic!( @@ -106,6 +94,20 @@ impl Expr { } } +impl Add for Expr { + type Output = Expr; + fn add(self: Expr, y: Expr) -> Expr { + Expr::Sum(Box::new(self), Box::new(y)) + } +} + +impl Mul for Expr { + type Output = Expr; + fn mul(self: Expr, y: Expr) -> Expr { + Expr::Mul(Box::new(self), Box::new(y)) + } +} + #[cfg(test)] mod tests { use super::*; @@ -114,11 +116,7 @@ mod tests { fn test_expr() { let expr = Expr::apply( 0, - Expr::apply( - 1, - Expr::sum(Expr::Variable(0), Expr::Variable(1)), - Expr::Const(4), - ), + Expr::apply(1, Expr::Variable(0) + Expr::Variable(1), Expr::Const(4)), Expr::Const(3), ); @@ -127,8 +125,8 @@ mod tests { #[test] fn test_derivative() { - let add_four = Expr::sum(Expr::Variable(0), Expr::Const(4)); - let mul_five = Expr::mul(Expr::Variable(1), Expr::Const(5)); + let add_four = Expr::Variable(0) + Expr::Const(4); + let mul_five = Expr::Variable(1) * Expr::Const(5); { let mul_five_then_add_four = Expr::apply(0, add_four.clone(), mul_five.clone()); diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs new file mode 100644 index 0000000..7590a38 --- /dev/null +++ b/little_learner/src/lib.rs @@ -0,0 +1,3 @@ +pub mod auto_diff; +pub mod expr_syntax_tree; +pub mod tensor; diff --git a/little_learner/src/tensor.rs b/little_learner/src/tensor.rs new file mode 100644 index 0000000..c769b40 --- /dev/null +++ b/little_learner/src/tensor.rs @@ -0,0 +1,107 @@ +#[macro_export] +macro_rules! tensor { + ($x:ty , $i: expr) => {[$x; $i]}; + ($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]}; +} + +#[cfg(test)] +mod tests { + #[test] + fn test_tensor_type() { + let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]]; + } +} + +pub trait Extensible1 { + fn apply(&self, other: &A, op: &F) -> Self + where + F: Fn(&A, &A) -> A; +} + +pub trait Extensible2 { + fn apply(&self, other: &Self, op: &F) -> Self + where + F: Fn(&A, &A) -> A; +} + +impl Extensible1 for [T; N] +where + T: Extensible1 + Copy + Default, +{ + fn apply(&self, other: &A, op: &F) -> Self + where + F: Fn(&A, &A) -> A, + { + let mut result = [Default::default(); N]; + for (i, coord) in self.iter().enumerate() { + result[i] = T::apply(coord, other, op); + } + result + } +} + +impl Extensible2 for [T; N] +where + T: Extensible2 + Copy + Default, +{ + fn apply(&self, other: &Self, op: &F) -> Self + where + F: Fn(&A, &A) -> A, + { + let mut result = [Default::default(); N]; + for (i, coord) in self.iter().enumerate() { + result[i] = T::apply(coord, &other[i], op); + } + result + } +} + +#[macro_export] +macro_rules! extensible1 { + ($x: ty) => { + impl Extensible1<$x> for $x { + fn apply(&self, other: &$x, op: &F) -> Self + where + F: Fn(&Self, &Self) -> Self, + { + op(self, other) + } + } + }; +} + +#[macro_export] +macro_rules! extensible2 { + ($x: ty) => { + impl Extensible2<$x> for $x { + fn apply(&self, other: &Self, op: &F) -> Self + where + F: Fn(&Self, &Self) -> Self, + { + op(self, other) + } + } + }; +} + +extensible1!(u8); +extensible1!(f64); + +extensible2!(u8); +extensible2!(f64); + +pub fn extension1(t1: &T, t2: &A, op: F) -> T +where + T: Extensible1, + F: Fn(&A, &A) -> A, +{ + t1.apply::(t2, &op) +} + +pub fn extension2(t1: &T, t2: &T, op: F) -> T +where + T: Extensible2, + F: Fn(&A, &A) -> A, +{ + t1.apply::(t2, &op) +} diff --git a/little_learner_app/Cargo.lock b/little_learner_app/Cargo.lock new file mode 100644 index 0000000..c30cc64 --- /dev/null +++ b/little_learner_app/Cargo.lock @@ -0,0 +1,156 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "arrayvec" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "immutable-chunkmap" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c" +dependencies = [ + "arrayvec", + "packed_struct", + "packed_struct_codegen", +] + +[[package]] +name = "little_learner" +version = "0.1.0" +dependencies = [ + "immutable-chunkmap", + "ordered-float", +] + +[[package]] +name = "little_learner_app" +version = "0.1.0" +dependencies = [ + "immutable-chunkmap", + "little_learner", + "ordered-float", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ordered-float" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13a384337e997e6860ffbaa83708b2ef329fd8c54cb67a5f64d421e0f943254f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "packed_struct" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190" +dependencies = [ + "bitvec", + "packed_struct_codegen", +] + +[[package]] +name = "packed_struct_codegen" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "unicode-ident" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] diff --git a/little_learner_app/Cargo.toml b/little_learner_app/Cargo.toml new file mode 100644 index 0000000..2324893 --- /dev/null +++ b/little_learner_app/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "little_learner_app" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +immutable-chunkmap = "1.0.5" +ordered-float = "3.6.0" +little_learner = { path = "../little_learner" } \ No newline at end of file diff --git a/src/main.rs b/little_learner_app/src/main.rs similarity index 62% rename from src/main.rs rename to little_learner_app/src/main.rs index a60defe..5fcba56 100644 --- a/src/main.rs +++ b/little_learner_app/src/main.rs @@ -1,4 +1,7 @@ -mod expr_syntax_tree; +use little_learner::auto_diff::{Differentiable, Scalar}; +use little_learner::tensor; +use little_learner::tensor::{extension2, Extensible2}; +use ordered_float::NotNan; use std::iter::Sum; use std::ops::{Mul, Sub}; @@ -7,106 +10,6 @@ type Point = [A; N]; type Parameters = [Point; M]; -#[macro_export] -macro_rules! tensor { - ($x:ty , $i: expr) => {[$x; $i]}; - ($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]}; -} - -pub trait Extensible1 { - fn apply(&self, other: &A, op: &F) -> Self - where - F: Fn(&A, &A) -> A; -} - -pub trait Extensible2 { - fn apply(&self, other: &Self, op: &F) -> Self - where - F: Fn(&A, &A) -> A; -} - -impl Extensible1 for [T; N] -where - T: Extensible1 + Copy + Default, -{ - fn apply(&self, other: &A, op: &F) -> Self - where - F: Fn(&A, &A) -> A, - { - let mut result = [Default::default(); N]; - for (i, coord) in self.iter().enumerate() { - result[i] = T::apply(coord, other, op); - } - result - } -} - -impl Extensible2 for [T; N] -where - T: Extensible2 + Copy + Default, -{ - fn apply(&self, other: &Self, op: &F) -> Self - where - F: Fn(&A, &A) -> A, - { - let mut result = [Default::default(); N]; - for (i, coord) in self.iter().enumerate() { - result[i] = T::apply(coord, &other[i], op); - } - result - } -} - -#[macro_export] -macro_rules! extensible1 { - ($x: ty) => { - impl Extensible1<$x> for $x { - fn apply(&self, other: &$x, op: &F) -> Self - where - F: Fn(&Self, &Self) -> Self, - { - op(self, other) - } - } - }; -} - -#[macro_export] -macro_rules! extensible2 { - ($x: ty) => { - impl Extensible2<$x> for $x { - fn apply(&self, other: &Self, op: &F) -> Self - where - F: Fn(&Self, &Self) -> Self, - { - op(self, other) - } - } - }; -} - -extensible1!(u8); -extensible1!(f64); - -extensible2!(u8); -extensible2!(f64); - -pub fn extension1(t1: &T, t2: &A, op: F) -> T -where - T: Extensible1, - F: Fn(&A, &A) -> A, -{ - t1.apply::(t2, &op) -} - -pub fn extension2(t1: &T, t2: &T, op: F) -> T -where - T: Extensible2, - F: Fn(&A, &A) -> A, -{ - t1.apply::(t2, &op) -} - fn dot_points(x: &Point, y: &Point) -> A where A: Sum<::Output> + Copy + Default + Mul + Extensible2, @@ -180,6 +83,14 @@ where result } +fn square(x: &A) -> A +where + A: Mul + Clone + std::fmt::Display, +{ + println!("{}", x); + x.clone() * x.clone() +} + fn main() { let loss = l2_loss( predict_line, @@ -188,16 +99,19 @@ fn main() { &[0.0099, 0.0], ); println!("{:?}", loss); + + let input_vec = Differentiable::Vector(Box::new([Differentiable::Scalar(Scalar::Number( + NotNan::new(27.0).expect("not nan"), + ))])); + + let grad = Differentiable::grad(|x| x.map(&|x| square(&x)), input_vec); + println!("{}", grad); } #[cfg(test)] mod tests { use super::*; - - #[test] - fn test_tensor_type() { - let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]]; - } + use little_learner::tensor::extension1; #[test] fn test_extension() {