Symbolic differentiation (#4)

This commit is contained in:
Patrick Stevens
2023-03-24 21:10:36 +00:00
committed by GitHub
parent c3bfeb0762
commit adff7ac3fd
4 changed files with 275 additions and 0 deletions

115
Cargo.lock generated
View File

@@ -2,6 +2,121 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "arrayvec"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "immutable-chunkmap"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c"
dependencies = [
"arrayvec",
"packed_struct",
"packed_struct_codegen",
]
[[package]] [[package]]
name = "little_learner" name = "little_learner"
version = "0.1.0" version = "0.1.0"
dependencies = [
"immutable-chunkmap",
]
[[package]]
name = "packed_struct"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190"
dependencies = [
"bitvec",
"packed_struct_codegen",
]
[[package]]
name = "packed_struct_codegen"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "unicode-ident"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]

View File

@@ -6,3 +6,4 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
immutable-chunkmap = "1.0.5"

157
src/expr_syntax_tree.rs Normal file
View File

@@ -0,0 +1,157 @@
#![allow(dead_code)]
use immutable_chunkmap::map;
use std::ops::{Add, Mul};
/*
An untyped syntax tree for an expression whose constants are all of type `A`.
*/
#[derive(Clone, Debug)]
pub enum Expr<A> {
Const(A),
Sum(Box<Expr<A>>, Box<Expr<A>>),
Variable(u32),
// The first `Expr` here is a function, which may reference the input variable `Variable(i)`.
// For example, `(fun x y -> x + y) 3 4` is expressed as:
// Apply(0, Apply(1, Sum(Variable(0), Variable(1)), Const(4)), Const(3))
Apply(u32, Box<Expr<A>>, Box<Expr<A>>),
Mul(Box<Expr<A>>, Box<Expr<A>>),
}
impl<A> Expr<A> {
fn eval_inner<const SIZE: usize>(e: &Expr<A>, ctx: &map::Map<u32, A, SIZE>) -> A
where
A: Clone + Add<Output = A> + Mul<Output = A>,
{
match &e {
Expr::Const(x) => x.clone(),
Expr::Sum(x, y) => Expr::eval_inner(x, ctx) + Expr::eval_inner(y, ctx),
Expr::Variable(id) => ctx
.get(id)
.unwrap_or_else(|| panic!("No binding found for free variable {}", id))
.clone(),
Expr::Apply(variable, func, arg) => {
let arg = Expr::eval_inner(arg, ctx);
let (updated_context, _) = ctx.insert(*variable, arg);
Expr::eval_inner(func, &updated_context)
}
Expr::Mul(x, y) => Expr::eval_inner(x, ctx) * Expr::eval_inner(y, ctx),
}
}
pub fn eval<const MAX_VAR_NUM: usize>(e: &Expr<A>) -> A
where
A: Clone + Add<Output = A> + Mul<Output = A>,
{
Expr::eval_inner(e, &map::Map::<u32, A, MAX_VAR_NUM>::new())
}
pub fn apply(var: u32, f: Expr<A>, arg: Expr<A>) -> Expr<A> {
Expr::Apply(var, Box::new(f), Box::new(arg))
}
pub fn sum(x: Expr<A>, y: Expr<A>) -> Expr<A> {
Expr::Sum(Box::new(x), Box::new(y))
}
pub fn mul(x: Expr<A>, y: Expr<A>) -> Expr<A> {
Expr::Mul(Box::new(x), Box::new(y))
}
pub fn differentiate(one: &A, zero: &A, var: u32, f: &Expr<A>) -> Expr<A>
where
A: Clone,
{
match f {
Expr::Const(_) => Expr::Const(zero.clone()),
Expr::Sum(x, y) => Expr::sum(
Expr::differentiate(one, zero, var, x),
Expr::differentiate(one, zero, var, y),
),
Expr::Variable(i) => {
if *i == var {
Expr::Const(one.clone())
} else {
Expr::Const(zero.clone())
}
}
Expr::Mul(x, y) => Expr::sum(
Expr::Mul(
Box::new(Expr::differentiate(one, zero, var, x.as_ref())),
(*y).clone(),
),
Expr::Mul(
Box::new(Expr::differentiate(one, zero, var, y.as_ref())),
(*x).clone(),
),
),
Expr::Apply(new_var, func, expr) => {
if *new_var == var {
panic!(
"cannot differentiate with respect to variable {} that's been assigned",
var
)
}
let expr_deriv = Expr::differentiate(one, zero, var, expr);
Expr::mul(
expr_deriv,
Expr::Apply(
*new_var,
Box::new(Expr::differentiate(one, zero, *new_var, func)),
(*expr).clone(),
),
)
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_expr() {
let expr = Expr::apply(
0,
Expr::apply(
1,
Expr::sum(Expr::Variable(0), Expr::Variable(1)),
Expr::Const(4),
),
Expr::Const(3),
);
assert_eq!(Expr::eval::<2>(&expr), 7);
}
#[test]
fn test_derivative() {
let add_four = Expr::sum(Expr::Variable(0), Expr::Const(4));
let mul_five = Expr::mul(Expr::Variable(1), Expr::Const(5));
{
let mul_five_then_add_four = Expr::apply(0, add_four.clone(), mul_five.clone());
let mul_then_add_diff = Expr::differentiate(&1, &0, 1, &mul_five_then_add_four);
for i in 3..10 {
// (5x + 4) differentiates to 5
assert_eq!(
Expr::eval::<2>(&Expr::apply(1, mul_then_add_diff.clone(), Expr::Const(i))),
5
);
}
}
{
let add_four_then_mul_five = Expr::apply(1, mul_five.clone(), add_four.clone());
let add_then_mul_diff = Expr::differentiate(&1, &0, 0, &add_four_then_mul_five);
for i in 3..10 {
// ((x + 4) * 5) differentiates to 5
assert_eq!(
Expr::eval::<2>(&Expr::apply(0, add_then_mul_diff.clone(), Expr::Const(i))),
5
);
}
}
}
}

View File

@@ -1,3 +1,5 @@
mod expr_syntax_tree;
use std::iter::Sum; use std::iter::Sum;
use std::ops::{Mul, Sub}; use std::ops::{Mul, Sub};