Symbolic differentiation (#4)
This commit is contained in:
115
Cargo.lock
generated
115
Cargo.lock
generated
@@ -2,6 +2,121 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "arrayvec"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
|
||||
|
||||
[[package]]
|
||||
name = "bitvec"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
|
||||
dependencies = [
|
||||
"funty",
|
||||
"radium",
|
||||
"tap",
|
||||
"wyz",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "funty"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
|
||||
|
||||
[[package]]
|
||||
name = "immutable-chunkmap"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"packed_struct",
|
||||
"packed_struct_codegen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "little_learner"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"immutable-chunkmap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packed_struct"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190"
|
||||
dependencies = [
|
||||
"bitvec",
|
||||
"packed_struct_codegen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packed_struct_codegen"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.53"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "radium"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tap"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"
|
||||
|
||||
[[package]]
|
||||
name = "wyz"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
|
||||
dependencies = [
|
||||
"tap",
|
||||
]
|
||||
|
@@ -6,3 +6,4 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
immutable-chunkmap = "1.0.5"
|
||||
|
157
src/expr_syntax_tree.rs
Normal file
157
src/expr_syntax_tree.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use immutable_chunkmap::map;
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
/*
|
||||
An untyped syntax tree for an expression whose constants are all of type `A`.
|
||||
*/
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Expr<A> {
|
||||
Const(A),
|
||||
Sum(Box<Expr<A>>, Box<Expr<A>>),
|
||||
Variable(u32),
|
||||
// The first `Expr` here is a function, which may reference the input variable `Variable(i)`.
|
||||
// For example, `(fun x y -> x + y) 3 4` is expressed as:
|
||||
// Apply(0, Apply(1, Sum(Variable(0), Variable(1)), Const(4)), Const(3))
|
||||
Apply(u32, Box<Expr<A>>, Box<Expr<A>>),
|
||||
Mul(Box<Expr<A>>, Box<Expr<A>>),
|
||||
}
|
||||
|
||||
impl<A> Expr<A> {
|
||||
fn eval_inner<const SIZE: usize>(e: &Expr<A>, ctx: &map::Map<u32, A, SIZE>) -> A
|
||||
where
|
||||
A: Clone + Add<Output = A> + Mul<Output = A>,
|
||||
{
|
||||
match &e {
|
||||
Expr::Const(x) => x.clone(),
|
||||
Expr::Sum(x, y) => Expr::eval_inner(x, ctx) + Expr::eval_inner(y, ctx),
|
||||
Expr::Variable(id) => ctx
|
||||
.get(id)
|
||||
.unwrap_or_else(|| panic!("No binding found for free variable {}", id))
|
||||
.clone(),
|
||||
Expr::Apply(variable, func, arg) => {
|
||||
let arg = Expr::eval_inner(arg, ctx);
|
||||
let (updated_context, _) = ctx.insert(*variable, arg);
|
||||
Expr::eval_inner(func, &updated_context)
|
||||
}
|
||||
Expr::Mul(x, y) => Expr::eval_inner(x, ctx) * Expr::eval_inner(y, ctx),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn eval<const MAX_VAR_NUM: usize>(e: &Expr<A>) -> A
|
||||
where
|
||||
A: Clone + Add<Output = A> + Mul<Output = A>,
|
||||
{
|
||||
Expr::eval_inner(e, &map::Map::<u32, A, MAX_VAR_NUM>::new())
|
||||
}
|
||||
|
||||
pub fn apply(var: u32, f: Expr<A>, arg: Expr<A>) -> Expr<A> {
|
||||
Expr::Apply(var, Box::new(f), Box::new(arg))
|
||||
}
|
||||
|
||||
pub fn sum(x: Expr<A>, y: Expr<A>) -> Expr<A> {
|
||||
Expr::Sum(Box::new(x), Box::new(y))
|
||||
}
|
||||
|
||||
pub fn mul(x: Expr<A>, y: Expr<A>) -> Expr<A> {
|
||||
Expr::Mul(Box::new(x), Box::new(y))
|
||||
}
|
||||
|
||||
pub fn differentiate(one: &A, zero: &A, var: u32, f: &Expr<A>) -> Expr<A>
|
||||
where
|
||||
A: Clone,
|
||||
{
|
||||
match f {
|
||||
Expr::Const(_) => Expr::Const(zero.clone()),
|
||||
Expr::Sum(x, y) => Expr::sum(
|
||||
Expr::differentiate(one, zero, var, x),
|
||||
Expr::differentiate(one, zero, var, y),
|
||||
),
|
||||
Expr::Variable(i) => {
|
||||
if *i == var {
|
||||
Expr::Const(one.clone())
|
||||
} else {
|
||||
Expr::Const(zero.clone())
|
||||
}
|
||||
}
|
||||
Expr::Mul(x, y) => Expr::sum(
|
||||
Expr::Mul(
|
||||
Box::new(Expr::differentiate(one, zero, var, x.as_ref())),
|
||||
(*y).clone(),
|
||||
),
|
||||
Expr::Mul(
|
||||
Box::new(Expr::differentiate(one, zero, var, y.as_ref())),
|
||||
(*x).clone(),
|
||||
),
|
||||
),
|
||||
Expr::Apply(new_var, func, expr) => {
|
||||
if *new_var == var {
|
||||
panic!(
|
||||
"cannot differentiate with respect to variable {} that's been assigned",
|
||||
var
|
||||
)
|
||||
}
|
||||
let expr_deriv = Expr::differentiate(one, zero, var, expr);
|
||||
Expr::mul(
|
||||
expr_deriv,
|
||||
Expr::Apply(
|
||||
*new_var,
|
||||
Box::new(Expr::differentiate(one, zero, *new_var, func)),
|
||||
(*expr).clone(),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_expr() {
|
||||
let expr = Expr::apply(
|
||||
0,
|
||||
Expr::apply(
|
||||
1,
|
||||
Expr::sum(Expr::Variable(0), Expr::Variable(1)),
|
||||
Expr::Const(4),
|
||||
),
|
||||
Expr::Const(3),
|
||||
);
|
||||
|
||||
assert_eq!(Expr::eval::<2>(&expr), 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derivative() {
|
||||
let add_four = Expr::sum(Expr::Variable(0), Expr::Const(4));
|
||||
let mul_five = Expr::mul(Expr::Variable(1), Expr::Const(5));
|
||||
|
||||
{
|
||||
let mul_five_then_add_four = Expr::apply(0, add_four.clone(), mul_five.clone());
|
||||
let mul_then_add_diff = Expr::differentiate(&1, &0, 1, &mul_five_then_add_four);
|
||||
for i in 3..10 {
|
||||
// (5x + 4) differentiates to 5
|
||||
assert_eq!(
|
||||
Expr::eval::<2>(&Expr::apply(1, mul_then_add_diff.clone(), Expr::Const(i))),
|
||||
5
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let add_four_then_mul_five = Expr::apply(1, mul_five.clone(), add_four.clone());
|
||||
let add_then_mul_diff = Expr::differentiate(&1, &0, 0, &add_four_then_mul_five);
|
||||
for i in 3..10 {
|
||||
// ((x + 4) * 5) differentiates to 5
|
||||
assert_eq!(
|
||||
Expr::eval::<2>(&Expr::apply(0, add_then_mul_diff.clone(), Expr::Const(i))),
|
||||
5
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@@ -1,3 +1,5 @@
|
||||
mod expr_syntax_tree;
|
||||
|
||||
use std::iter::Sum;
|
||||
use std::ops::{Mul, Sub};
|
||||
|
||||
|
Reference in New Issue
Block a user