diff --git a/Cargo.lock b/Cargo.lock index e910ad4..bb58012 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,121 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "arrayvec" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" + +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + +[[package]] +name = "immutable-chunkmap" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c" +dependencies = [ + "arrayvec", + "packed_struct", + "packed_struct_codegen", +] + [[package]] name = "little_learner" version = "0.1.0" +dependencies = [ + "immutable-chunkmap", +] + +[[package]] +name = "packed_struct" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190" +dependencies = [ + "bitvec", + "packed_struct_codegen", +] + +[[package]] +name = "packed_struct_codegen" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + +[[package]] +name = "unicode-ident" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" + +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] diff --git a/Cargo.toml b/Cargo.toml index babbc5e..a2558bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +immutable-chunkmap = "1.0.5" diff --git a/src/expr_syntax_tree.rs b/src/expr_syntax_tree.rs new file mode 100644 index 0000000..7c9d2b1 --- /dev/null +++ b/src/expr_syntax_tree.rs @@ -0,0 +1,157 @@ +#![allow(dead_code)] + +use immutable_chunkmap::map; +use std::ops::{Add, Mul}; + +/* +An untyped syntax tree for an expression whose constants are all of type `A`. +*/ +#[derive(Clone, Debug)] +pub enum Expr { + Const(A), + Sum(Box>, Box>), + Variable(u32), + // The first `Expr` here is a function, which may reference the input variable `Variable(i)`. + // For example, `(fun x y -> x + y) 3 4` is expressed as: + // Apply(0, Apply(1, Sum(Variable(0), Variable(1)), Const(4)), Const(3)) + Apply(u32, Box>, Box>), + Mul(Box>, Box>), +} + +impl Expr { + fn eval_inner(e: &Expr, ctx: &map::Map) -> A + where + A: Clone + Add + Mul, + { + match &e { + Expr::Const(x) => x.clone(), + Expr::Sum(x, y) => Expr::eval_inner(x, ctx) + Expr::eval_inner(y, ctx), + Expr::Variable(id) => ctx + .get(id) + .unwrap_or_else(|| panic!("No binding found for free variable {}", id)) + .clone(), + Expr::Apply(variable, func, arg) => { + let arg = Expr::eval_inner(arg, ctx); + let (updated_context, _) = ctx.insert(*variable, arg); + Expr::eval_inner(func, &updated_context) + } + Expr::Mul(x, y) => Expr::eval_inner(x, ctx) * Expr::eval_inner(y, ctx), + } + } + + pub fn eval(e: &Expr) -> A + where + A: Clone + Add + Mul, + { + Expr::eval_inner(e, &map::Map::::new()) + } + + pub fn apply(var: u32, f: Expr, arg: Expr) -> Expr { + Expr::Apply(var, Box::new(f), Box::new(arg)) + } + + pub fn sum(x: Expr, y: Expr) -> Expr { + Expr::Sum(Box::new(x), Box::new(y)) + } + + pub fn mul(x: Expr, y: Expr) -> Expr { + Expr::Mul(Box::new(x), Box::new(y)) + } + + pub fn differentiate(one: &A, zero: &A, var: u32, f: &Expr) -> Expr + where + A: Clone, + { + match f { + Expr::Const(_) => Expr::Const(zero.clone()), + Expr::Sum(x, y) => Expr::sum( + Expr::differentiate(one, zero, var, x), + Expr::differentiate(one, zero, var, y), + ), + Expr::Variable(i) => { + if *i == var { + Expr::Const(one.clone()) + } else { + Expr::Const(zero.clone()) + } + } + Expr::Mul(x, y) => Expr::sum( + Expr::Mul( + Box::new(Expr::differentiate(one, zero, var, x.as_ref())), + (*y).clone(), + ), + Expr::Mul( + Box::new(Expr::differentiate(one, zero, var, y.as_ref())), + (*x).clone(), + ), + ), + Expr::Apply(new_var, func, expr) => { + if *new_var == var { + panic!( + "cannot differentiate with respect to variable {} that's been assigned", + var + ) + } + let expr_deriv = Expr::differentiate(one, zero, var, expr); + Expr::mul( + expr_deriv, + Expr::Apply( + *new_var, + Box::new(Expr::differentiate(one, zero, *new_var, func)), + (*expr).clone(), + ), + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_expr() { + let expr = Expr::apply( + 0, + Expr::apply( + 1, + Expr::sum(Expr::Variable(0), Expr::Variable(1)), + Expr::Const(4), + ), + Expr::Const(3), + ); + + assert_eq!(Expr::eval::<2>(&expr), 7); + } + + #[test] + fn test_derivative() { + let add_four = Expr::sum(Expr::Variable(0), Expr::Const(4)); + let mul_five = Expr::mul(Expr::Variable(1), Expr::Const(5)); + + { + let mul_five_then_add_four = Expr::apply(0, add_four.clone(), mul_five.clone()); + let mul_then_add_diff = Expr::differentiate(&1, &0, 1, &mul_five_then_add_four); + for i in 3..10 { + // (5x + 4) differentiates to 5 + assert_eq!( + Expr::eval::<2>(&Expr::apply(1, mul_then_add_diff.clone(), Expr::Const(i))), + 5 + ); + } + } + + { + let add_four_then_mul_five = Expr::apply(1, mul_five.clone(), add_four.clone()); + let add_then_mul_diff = Expr::differentiate(&1, &0, 0, &add_four_then_mul_five); + for i in 3..10 { + // ((x + 4) * 5) differentiates to 5 + assert_eq!( + Expr::eval::<2>(&Expr::apply(0, add_then_mul_diff.clone(), Expr::Const(i))), + 5 + ); + } + } + } +} diff --git a/src/main.rs b/src/main.rs index e7101c7..a60defe 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +mod expr_syntax_tree; + use std::iter::Sum; use std::ops::{Mul, Sub};