Automatic differentiation (#5)

This commit is contained in:
Patrick Stevens
2023-03-25 22:19:04 +00:00
committed by GitHub
parent adff7ac3fd
commit 32caf8d7d6
14 changed files with 894 additions and 150 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
/target
target/
.idea/
*.iml
.vscode/

34
Cargo.lock generated
View File

@@ -8,6 +8,12 @@ version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bitvec"
version = "1.0.1"
@@ -42,6 +48,34 @@ name = "little_learner"
version = "0.1.0"
dependencies = [
"immutable-chunkmap",
"ordered-float",
]
[[package]]
name = "little_learner_app"
version = "0.1.0"
dependencies = [
"immutable-chunkmap",
"little_learner",
"ordered-float",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]]
name = "ordered-float"
version = "3.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13a384337e997e6860ffbaa83708b2ef329fd8c54cb67a5f64d421e0f943254f"
dependencies = [
"num-traits",
]
[[package]]

View File

@@ -1,9 +1,5 @@
[package]
name = "little_learner"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
immutable-chunkmap = "1.0.5"
[workspace]
members = [
"little_learner",
"little_learner_app"
]

12
flake.lock generated
View File

@@ -49,11 +49,11 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1678898370,
"narHash": "sha256-xTICr1j+uat5hk9FyuPOFGxpWHdJRibwZC+ATi0RbtE=",
"lastModified": 1679705136,
"narHash": "sha256-MDlZUR7wJ3PlPtqwwoGQr3euNOe0vdSSteVVOef7tBY=",
"owner": "nixos",
"repo": "nixpkgs",
"rev": "ac718d02867a84b42522a0ece52d841188208f2c",
"rev": "8f40f2f90b9c9032d1b824442cfbbe0dbabd0dbd",
"type": "github"
},
"original": {
@@ -65,11 +65,11 @@
},
"nixpkgs_2": {
"locked": {
"lastModified": 1665296151,
"narHash": "sha256-uOB0oxqxN9K7XGF1hcnY+PQnlQJ+3bP2vCn/+Ru/bbc=",
"lastModified": 1679734080,
"narHash": "sha256-z846xfGLlon6t9lqUzlNtBOmsgQLQIZvR6Lt2dImk1M=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "14ccaaedd95a488dd7ae142757884d8e125b3363",
"rev": "dbf5322e93bcc6cfc52268367a8ad21c09d76fea",
"type": "github"
},
"original": {

View File

@@ -23,7 +23,7 @@
crate2nix,
...
}: let
name = "little_learner";
name = "little_learner_app";
in
utils.lib.eachDefaultSystem
(
@@ -77,7 +77,7 @@
PKG_CONFIG_PATH = "${pkgs.openssl.dev}/lib/pkgconfig";
};
in rec {
packages.${name} = project.rootCrate.build;
packages.${name} = project.workspaceMembers.${name}.build;
# `nix build`
defaultPackage = packages.${name};

147
little_learner/Cargo.lock generated Normal file
View File

@@ -0,0 +1,147 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "arrayvec"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "immutable-chunkmap"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c"
dependencies = [
"arrayvec",
"packed_struct",
"packed_struct_codegen",
]
[[package]]
name = "little_learner"
version = "0.1.0"
dependencies = [
"immutable-chunkmap",
"ordered-float",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]]
name = "ordered-float"
version = "3.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13a384337e997e6860ffbaa83708b2ef329fd8c54cb67a5f64d421e0f943254f"
dependencies = [
"num-traits",
]
[[package]]
name = "packed_struct"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190"
dependencies = [
"bitvec",
"packed_struct_codegen",
]
[[package]]
name = "packed_struct_codegen"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "unicode-ident"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]

12
little_learner/Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[package]
name = "little_learner"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
immutable-chunkmap = "1.0.5"
ordered-float = "3.6.0"
[lib]

View File

@@ -0,0 +1,365 @@
use core::hash::Hash;
use ordered_float::NotNan;
use std::{
collections::{hash_map::Entry, HashMap},
fmt::{Display, Write},
ops::{Add, AddAssign, Div, Mul},
};
pub trait Zero {
fn zero() -> Self;
}
pub trait One {
fn one() -> Self;
}
impl Zero for f64 {
fn zero() -> Self {
0.0
}
}
impl One for f64 {
fn one() -> Self {
1.0
}
}
impl Zero for NotNan<f64> {
fn zero() -> Self {
NotNan::new(0.0).unwrap()
}
}
impl One for NotNan<f64> {
fn one() -> Self {
NotNan::new(1.0).unwrap()
}
}
impl<A> Zero for Differentiable<A>
where
A: Zero,
{
fn zero() -> Differentiable<A> {
Differentiable::Scalar(Scalar::Number(A::zero()))
}
}
impl<A> One for Differentiable<A>
where
A: One,
{
fn one() -> Differentiable<A> {
Differentiable::Scalar(Scalar::Number(A::one()))
}
}
pub trait Exp {
fn exp(self) -> Self;
}
impl Exp for NotNan<f64> {
fn exp(self) -> Self {
NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN")
}
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum LinkData<A> {
Addition(Box<Scalar<A>>, Box<Scalar<A>>),
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
Exponent(Box<Scalar<A>>),
Log(Box<Scalar<A>>),
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Link<A> {
EndOfLink,
Link(LinkData<A>),
}
impl<A> Display for Link<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Link::EndOfLink => f.write_str("<end>"),
Link::Link(LinkData::Addition(left, right)) => {
f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref()))
}
Link::Link(LinkData::Mul(left, right)) => {
f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref()))
}
Link::Link(LinkData::Exponent(arg)) => {
f.write_fmt(format_args!("exp({})", arg.as_ref()))
}
Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())),
}
}
}
impl<A> Link<A> {
fn invoke(self, d: &Scalar<A>, z: A, acc: &mut HashMap<Scalar<A>, A>)
where
A: Eq + Hash + AddAssign + Clone + Exp + Mul<Output = A> + Div<Output = A> + Zero + One,
{
match self {
Link::EndOfLink => match acc.entry(d.clone()) {
Entry::Occupied(mut o) => {
let entry = o.get_mut();
*entry += z;
}
Entry::Vacant(v) => {
v.insert(z);
}
},
Link::Link(data) => {
match data {
LinkData::Addition(left, right) => {
// The `z` here reflects the fact that dx/dx = 1, so it's 1 * z.
left.as_ref().clone_link().invoke(&left, z.clone(), acc);
right.as_ref().clone_link().invoke(&right, z, acc);
}
LinkData::Exponent(arg) => {
// d/dx (e^x) = exp x, so exp z * z.
arg.as_ref().clone_link().invoke(
&arg,
z * arg.clone_real_part().exp(),
acc,
);
}
LinkData::Mul(left, right) => {
// d/dx(f g) = f dg/dx + g df/dx
left.as_ref().clone_link().invoke(
&left,
right.clone_real_part() * z.clone(),
acc,
);
right
.as_ref()
.clone_link()
.invoke(&right, left.clone_real_part() * z, acc);
}
LinkData::Log(arg) => {
// d/dx(log y) = 1/y dy/dx
arg.as_ref().clone_link().invoke(
&arg,
A::one() / arg.clone_real_part() * z,
acc,
);
}
}
}
}
}
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Scalar<A> {
Number(A),
// The value, and the link.
Dual(A, Link<A>),
}
impl<A> Add for Scalar<A>
where
A: Add<Output = A> + Clone,
{
type Output = Scalar<A>;
fn add(self, rhs: Self) -> Self::Output {
Scalar::Dual(
self.clone_real_part() + rhs.clone_real_part(),
Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))),
)
}
}
impl<A> Mul for Scalar<A>
where
A: Mul<Output = A> + Clone,
{
type Output = Scalar<A>;
fn mul(self, rhs: Self) -> Self::Output {
Scalar::Dual(
self.clone_real_part() * rhs.clone_real_part(),
Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))),
)
}
}
impl<A> Scalar<A> {
pub fn real_part(&self) -> &A {
match self {
Scalar::Number(a) => a,
Scalar::Dual(a, _) => a,
}
}
fn clone_real_part(&self) -> A
where
A: Clone,
{
match self {
Scalar::Number(a) => (*a).clone(),
Scalar::Dual(a, _) => (*a).clone(),
}
}
pub fn link(self) -> Link<A> {
match self {
Scalar::Dual(_, link) => link,
Scalar::Number(_) => Link::EndOfLink,
}
}
fn clone_link(&self) -> Link<A>
where
A: Clone,
{
match self {
Scalar::Dual(_, data) => data.clone(),
Scalar::Number(_) => Link::EndOfLink,
}
}
fn truncate_dual(self) -> Scalar<A>
where
A: Clone,
{
Scalar::Dual(self.clone_real_part(), Link::EndOfLink)
}
}
impl<A> Display for Scalar<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Scalar::Number(n) => f.write_fmt(format_args!("{}", n)),
Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)),
}
}
}
pub enum Differentiable<A> {
Scalar(Scalar<A>),
Vector(Box<[Differentiable<A>]>),
}
impl<A> Display for Differentiable<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
Differentiable::Vector(v) => {
f.write_char('[')?;
for v in v.iter() {
f.write_fmt(format_args!("{}", v))?;
f.write_char(',')?;
}
f.write_char(']')
}
}
}
}
impl<A> Differentiable<A> {
pub fn map<B, F>(&self, f: &F) -> Differentiable<B>
where
F: Fn(Scalar<A>) -> Scalar<B>,
A: Clone,
{
match self {
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
Differentiable::Vector(slice) => {
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
}
}
}
}
impl<A> Differentiable<A>
where
A: Clone + Eq + Hash + AddAssign + Mul<Output = A> + Exp + Div<Output = A> + Zero + One,
{
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
for v in v.iter().rev() {
v.accumulate_gradients(acc);
}
}
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
match self {
Differentiable::Scalar(y) => {
let k = y.clone_link();
k.invoke(y, A::one(), acc);
}
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc),
}
}
fn grad_once(self, wrt: Differentiable<A>) -> Differentiable<A> {
let mut acc = HashMap::new();
self.accumulate_gradients(&mut acc);
wrt.map(&|d| match acc.get(&d) {
None => Scalar::Number(A::zero()),
Some(x) => Scalar::Number(x.clone()),
})
}
pub fn grad<F>(f: F, theta: Differentiable<A>) -> Differentiable<A>
where
F: Fn(&Differentiable<A>) -> Differentiable<A>,
{
let wrt = theta.map(&Scalar::truncate_dual);
let after_f = f(&wrt);
Differentiable::grad_once(after_f, wrt)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
match d {
Differentiable::Scalar(a) => &(a.real_part()),
Differentiable::Vector(_) => panic!("not a scalar"),
}
}
#[test]
fn test_map() {
let v = Differentiable::Vector(
vec![
Differentiable::Scalar(Scalar::Number(NotNan::new(3.0).expect("3 is not NaN"))),
Differentiable::Scalar(Scalar::Number(NotNan::new(4.0).expect("4 is not NaN"))),
]
.into(),
);
let mapped = v.map(&|x: Scalar<NotNan<f64>>| match x {
Scalar::Number(i) => Scalar::Number(i + NotNan::new(1.0).expect("1 is not NaN")),
Scalar::Dual(_, _) => panic!("Not hit"),
});
let v = match mapped {
Differentiable::Scalar(_) => panic!("Not a scalar"),
Differentiable::Vector(v) => v
.as_ref()
.iter()
.map(|d| extract_scalar(d).clone())
.collect::<Vec<_>>(),
};
assert_eq!(v, [4.0, 5.0]);
}
}

View File

@@ -1,5 +1,3 @@
#![allow(dead_code)]
use immutable_chunkmap::map;
use std::ops::{Add, Mul};
@@ -50,24 +48,15 @@ impl<A> Expr<A> {
Expr::Apply(var, Box::new(f), Box::new(arg))
}
pub fn sum(x: Expr<A>, y: Expr<A>) -> Expr<A> {
Expr::Sum(Box::new(x), Box::new(y))
}
pub fn mul(x: Expr<A>, y: Expr<A>) -> Expr<A> {
Expr::Mul(Box::new(x), Box::new(y))
}
pub fn differentiate(one: &A, zero: &A, var: u32, f: &Expr<A>) -> Expr<A>
where
A: Clone,
{
match f {
Expr::Const(_) => Expr::Const(zero.clone()),
Expr::Sum(x, y) => Expr::sum(
Expr::differentiate(one, zero, var, x),
Expr::differentiate(one, zero, var, y),
),
Expr::Sum(x, y) => {
Expr::differentiate(one, zero, var, x) + Expr::differentiate(one, zero, var, y)
}
Expr::Variable(i) => {
if *i == var {
Expr::Const(one.clone())
@@ -75,16 +64,15 @@ impl<A> Expr<A> {
Expr::Const(zero.clone())
}
}
Expr::Mul(x, y) => Expr::sum(
Expr::Mul(x, y) => {
Expr::Mul(
Box::new(Expr::differentiate(one, zero, var, x.as_ref())),
(*y).clone(),
),
Expr::Mul(
) + Expr::Mul(
Box::new(Expr::differentiate(one, zero, var, y.as_ref())),
(*x).clone(),
),
),
)
}
Expr::Apply(new_var, func, expr) => {
if *new_var == var {
panic!(
@@ -106,6 +94,20 @@ impl<A> Expr<A> {
}
}
impl<A> Add for Expr<A> {
type Output = Expr<A>;
fn add(self: Expr<A>, y: Expr<A>) -> Expr<A> {
Expr::Sum(Box::new(self), Box::new(y))
}
}
impl<A> Mul for Expr<A> {
type Output = Expr<A>;
fn mul(self: Expr<A>, y: Expr<A>) -> Expr<A> {
Expr::Mul(Box::new(self), Box::new(y))
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -114,11 +116,7 @@ mod tests {
fn test_expr() {
let expr = Expr::apply(
0,
Expr::apply(
1,
Expr::sum(Expr::Variable(0), Expr::Variable(1)),
Expr::Const(4),
),
Expr::apply(1, Expr::Variable(0) + Expr::Variable(1), Expr::Const(4)),
Expr::Const(3),
);
@@ -127,8 +125,8 @@ mod tests {
#[test]
fn test_derivative() {
let add_four = Expr::sum(Expr::Variable(0), Expr::Const(4));
let mul_five = Expr::mul(Expr::Variable(1), Expr::Const(5));
let add_four = Expr::Variable(0) + Expr::Const(4);
let mul_five = Expr::Variable(1) * Expr::Const(5);
{
let mul_five_then_add_four = Expr::apply(0, add_four.clone(), mul_five.clone());

View File

@@ -0,0 +1,3 @@
pub mod auto_diff;
pub mod expr_syntax_tree;
pub mod tensor;

View File

@@ -0,0 +1,107 @@
#[macro_export]
macro_rules! tensor {
($x:ty , $i: expr) => {[$x; $i]};
($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]};
}
#[cfg(test)]
mod tests {
#[test]
fn test_tensor_type() {
let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]];
}
}
pub trait Extensible1<A> {
fn apply<F>(&self, other: &A, op: &F) -> Self
where
F: Fn(&A, &A) -> A;
}
pub trait Extensible2<A> {
fn apply<F>(&self, other: &Self, op: &F) -> Self
where
F: Fn(&A, &A) -> A;
}
impl<A, T, const N: usize> Extensible1<A> for [T; N]
where
T: Extensible1<A> + Copy + Default,
{
fn apply<F>(&self, other: &A, op: &F) -> Self
where
F: Fn(&A, &A) -> A,
{
let mut result = [Default::default(); N];
for (i, coord) in self.iter().enumerate() {
result[i] = T::apply(coord, other, op);
}
result
}
}
impl<A, T, const N: usize> Extensible2<A> for [T; N]
where
T: Extensible2<A> + Copy + Default,
{
fn apply<F>(&self, other: &Self, op: &F) -> Self
where
F: Fn(&A, &A) -> A,
{
let mut result = [Default::default(); N];
for (i, coord) in self.iter().enumerate() {
result[i] = T::apply(coord, &other[i], op);
}
result
}
}
#[macro_export]
macro_rules! extensible1 {
($x: ty) => {
impl Extensible1<$x> for $x {
fn apply<F>(&self, other: &$x, op: &F) -> Self
where
F: Fn(&Self, &Self) -> Self,
{
op(self, other)
}
}
};
}
#[macro_export]
macro_rules! extensible2 {
($x: ty) => {
impl Extensible2<$x> for $x {
fn apply<F>(&self, other: &Self, op: &F) -> Self
where
F: Fn(&Self, &Self) -> Self,
{
op(self, other)
}
}
};
}
extensible1!(u8);
extensible1!(f64);
extensible2!(u8);
extensible2!(f64);
pub fn extension1<T, A, F>(t1: &T, t2: &A, op: F) -> T
where
T: Extensible1<A>,
F: Fn(&A, &A) -> A,
{
t1.apply::<F>(t2, &op)
}
pub fn extension2<T, A, F>(t1: &T, t2: &T, op: F) -> T
where
T: Extensible2<A>,
F: Fn(&A, &A) -> A,
{
t1.apply::<F>(t2, &op)
}

156
little_learner_app/Cargo.lock generated Normal file
View File

@@ -0,0 +1,156 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "arrayvec"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6"
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bitvec"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
dependencies = [
"funty",
"radium",
"tap",
"wyz",
]
[[package]]
name = "funty"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "immutable-chunkmap"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7617eb072b88069788fa9d5cadae34faebca64e5325ec5deaa2b4c96510f9e8c"
dependencies = [
"arrayvec",
"packed_struct",
"packed_struct_codegen",
]
[[package]]
name = "little_learner"
version = "0.1.0"
dependencies = [
"immutable-chunkmap",
"ordered-float",
]
[[package]]
name = "little_learner_app"
version = "0.1.0"
dependencies = [
"immutable-chunkmap",
"little_learner",
"ordered-float",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]]
name = "ordered-float"
version = "3.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13a384337e997e6860ffbaa83708b2ef329fd8c54cb67a5f64d421e0f943254f"
dependencies = [
"num-traits",
]
[[package]]
name = "packed_struct"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36b29691432cc9eff8b282278473b63df73bea49bc3ec5e67f31a3ae9c3ec190"
dependencies = [
"bitvec",
"packed_struct_codegen",
]
[[package]]
name = "packed_struct_codegen"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cd6706dfe50d53e0f6aa09e12c034c44faacd23e966ae5a209e8bdb8f179f98"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc"
dependencies = [
"proc-macro2",
]
[[package]]
name = "radium"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tap"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
[[package]]
name = "unicode-ident"
version = "1.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4"
[[package]]
name = "wyz"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
dependencies = [
"tap",
]

View File

@@ -0,0 +1,11 @@
[package]
name = "little_learner_app"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
immutable-chunkmap = "1.0.5"
ordered-float = "3.6.0"
little_learner = { path = "../little_learner" }

View File

@@ -1,4 +1,7 @@
mod expr_syntax_tree;
use little_learner::auto_diff::{Differentiable, Scalar};
use little_learner::tensor;
use little_learner::tensor::{extension2, Extensible2};
use ordered_float::NotNan;
use std::iter::Sum;
use std::ops::{Mul, Sub};
@@ -7,106 +10,6 @@ type Point<A, const N: usize> = [A; N];
type Parameters<A, const N: usize, const M: usize> = [Point<A, N>; M];
#[macro_export]
macro_rules! tensor {
($x:ty , $i: expr) => {[$x; $i]};
($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]};
}
pub trait Extensible1<A> {
fn apply<F>(&self, other: &A, op: &F) -> Self
where
F: Fn(&A, &A) -> A;
}
pub trait Extensible2<A> {
fn apply<F>(&self, other: &Self, op: &F) -> Self
where
F: Fn(&A, &A) -> A;
}
impl<A, T, const N: usize> Extensible1<A> for [T; N]
where
T: Extensible1<A> + Copy + Default,
{
fn apply<F>(&self, other: &A, op: &F) -> Self
where
F: Fn(&A, &A) -> A,
{
let mut result = [Default::default(); N];
for (i, coord) in self.iter().enumerate() {
result[i] = T::apply(coord, other, op);
}
result
}
}
impl<A, T, const N: usize> Extensible2<A> for [T; N]
where
T: Extensible2<A> + Copy + Default,
{
fn apply<F>(&self, other: &Self, op: &F) -> Self
where
F: Fn(&A, &A) -> A,
{
let mut result = [Default::default(); N];
for (i, coord) in self.iter().enumerate() {
result[i] = T::apply(coord, &other[i], op);
}
result
}
}
#[macro_export]
macro_rules! extensible1 {
($x: ty) => {
impl Extensible1<$x> for $x {
fn apply<F>(&self, other: &$x, op: &F) -> Self
where
F: Fn(&Self, &Self) -> Self,
{
op(self, other)
}
}
};
}
#[macro_export]
macro_rules! extensible2 {
($x: ty) => {
impl Extensible2<$x> for $x {
fn apply<F>(&self, other: &Self, op: &F) -> Self
where
F: Fn(&Self, &Self) -> Self,
{
op(self, other)
}
}
};
}
extensible1!(u8);
extensible1!(f64);
extensible2!(u8);
extensible2!(f64);
pub fn extension1<T, A, F>(t1: &T, t2: &A, op: F) -> T
where
T: Extensible1<A>,
F: Fn(&A, &A) -> A,
{
t1.apply::<F>(t2, &op)
}
pub fn extension2<T, A, F>(t1: &T, t2: &T, op: F) -> T
where
T: Extensible2<A>,
F: Fn(&A, &A) -> A,
{
t1.apply::<F>(t2, &op)
}
fn dot_points<A: Mul, const N: usize>(x: &Point<A, N>, y: &Point<A, N>) -> A
where
A: Sum<<A as Mul>::Output> + Copy + Default + Mul<Output = A> + Extensible2<A>,
@@ -180,6 +83,14 @@ where
result
}
fn square<A>(x: &A) -> A
where
A: Mul<Output = A> + Clone + std::fmt::Display,
{
println!("{}", x);
x.clone() * x.clone()
}
fn main() {
let loss = l2_loss(
predict_line,
@@ -188,16 +99,19 @@ fn main() {
&[0.0099, 0.0],
);
println!("{:?}", loss);
let input_vec = Differentiable::Vector(Box::new([Differentiable::Scalar(Scalar::Number(
NotNan::new(27.0).expect("not nan"),
))]));
let grad = Differentiable::grad(|x| x.map(&|x| square(&x)), input_vec);
println!("{}", grad);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tensor_type() {
let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]];
}
use little_learner::tensor::extension1;
#[test]
fn test_extension() {