Add rank parameters to autodiff (#6)

This commit is contained in:
Patrick Stevens
2023-03-29 21:00:13 +01:00
committed by GitHub
parent 32caf8d7d6
commit 0d2e5eb277
10 changed files with 685 additions and 362 deletions

View File

@@ -0,0 +1 @@
nightly

View File

@@ -1,265 +1,64 @@
use crate::scalar::Scalar;
use crate::traits::{Exp, One, Zero};
use core::hash::Hash;
use ordered_float::NotNan;
use std::collections::HashMap;
use std::{
collections::{hash_map::Entry, HashMap},
fmt::{Display, Write},
ops::{Add, AddAssign, Div, Mul},
ops::{AddAssign, Div, Mul, Neg},
};
pub trait Zero {
fn zero() -> Self;
}
pub trait One {
fn one() -> Self;
}
impl Zero for f64 {
fn zero() -> Self {
0.0
}
}
impl One for f64 {
fn one() -> Self {
1.0
}
}
impl Zero for NotNan<f64> {
fn zero() -> Self {
NotNan::new(0.0).unwrap()
}
}
impl One for NotNan<f64> {
fn one() -> Self {
NotNan::new(1.0).unwrap()
}
}
impl<A> Zero for Differentiable<A>
impl<A> Zero for DifferentiableHidden<A>
where
A: Zero,
{
fn zero() -> Differentiable<A> {
Differentiable::Scalar(Scalar::Number(A::zero()))
fn zero() -> DifferentiableHidden<A> {
DifferentiableHidden::Scalar(Scalar::Number(A::zero()))
}
}
impl<A> One for Differentiable<A>
impl<A> One for Scalar<A>
where
A: One,
{
fn one() -> Differentiable<A> {
Differentiable::Scalar(Scalar::Number(A::one()))
fn one() -> Scalar<A> {
Scalar::Number(A::one())
}
}
pub trait Exp {
fn exp(self) -> Self;
}
impl Exp for NotNan<f64> {
fn exp(self) -> Self {
NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN")
}
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum LinkData<A> {
Addition(Box<Scalar<A>>, Box<Scalar<A>>),
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
Exponent(Box<Scalar<A>>),
Log(Box<Scalar<A>>),
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Link<A> {
EndOfLink,
Link(LinkData<A>),
}
impl<A> Display for Link<A>
impl<A> One for DifferentiableHidden<A>
where
A: Display,
A: One,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Link::EndOfLink => f.write_str("<end>"),
Link::Link(LinkData::Addition(left, right)) => {
f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref()))
}
Link::Link(LinkData::Mul(left, right)) => {
f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref()))
}
Link::Link(LinkData::Exponent(arg)) => {
f.write_fmt(format_args!("exp({})", arg.as_ref()))
}
Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())),
}
fn one() -> DifferentiableHidden<A> {
DifferentiableHidden::Scalar(Scalar::one())
}
}
impl<A> Link<A> {
fn invoke(self, d: &Scalar<A>, z: A, acc: &mut HashMap<Scalar<A>, A>)
where
A: Eq + Hash + AddAssign + Clone + Exp + Mul<Output = A> + Div<Output = A> + Zero + One,
{
match self {
Link::EndOfLink => match acc.entry(d.clone()) {
Entry::Occupied(mut o) => {
let entry = o.get_mut();
*entry += z;
}
Entry::Vacant(v) => {
v.insert(z);
}
},
Link::Link(data) => {
match data {
LinkData::Addition(left, right) => {
// The `z` here reflects the fact that dx/dx = 1, so it's 1 * z.
left.as_ref().clone_link().invoke(&left, z.clone(), acc);
right.as_ref().clone_link().invoke(&right, z, acc);
}
LinkData::Exponent(arg) => {
// d/dx (e^x) = exp x, so exp z * z.
arg.as_ref().clone_link().invoke(
&arg,
z * arg.clone_real_part().exp(),
acc,
);
}
LinkData::Mul(left, right) => {
// d/dx(f g) = f dg/dx + g df/dx
left.as_ref().clone_link().invoke(
&left,
right.clone_real_part() * z.clone(),
acc,
);
right
.as_ref()
.clone_link()
.invoke(&right, left.clone_real_part() * z, acc);
}
LinkData::Log(arg) => {
// d/dx(log y) = 1/y dy/dx
arg.as_ref().clone_link().invoke(
&arg,
A::one() / arg.clone_real_part() * z,
acc,
);
}
}
}
}
}
}
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Scalar<A> {
Number(A),
// The value, and the link.
Dual(A, Link<A>),
}
impl<A> Add for Scalar<A>
impl<A> Clone for DifferentiableHidden<A>
where
A: Add<Output = A> + Clone,
A: Clone,
{
type Output = Scalar<A>;
fn add(self, rhs: Self) -> Self::Output {
Scalar::Dual(
self.clone_real_part() + rhs.clone_real_part(),
Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))),
)
}
}
impl<A> Mul for Scalar<A>
where
A: Mul<Output = A> + Clone,
{
type Output = Scalar<A>;
fn mul(self, rhs: Self) -> Self::Output {
Scalar::Dual(
self.clone_real_part() * rhs.clone_real_part(),
Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))),
)
}
}
impl<A> Scalar<A> {
pub fn real_part(&self) -> &A {
fn clone(&self) -> Self {
match self {
Scalar::Number(a) => a,
Scalar::Dual(a, _) => a,
}
}
fn clone_real_part(&self) -> A
where
A: Clone,
{
match self {
Scalar::Number(a) => (*a).clone(),
Scalar::Dual(a, _) => (*a).clone(),
}
}
pub fn link(self) -> Link<A> {
match self {
Scalar::Dual(_, link) => link,
Scalar::Number(_) => Link::EndOfLink,
}
}
fn clone_link(&self) -> Link<A>
where
A: Clone,
{
match self {
Scalar::Dual(_, data) => data.clone(),
Scalar::Number(_) => Link::EndOfLink,
}
}
fn truncate_dual(self) -> Scalar<A>
where
A: Clone,
{
Scalar::Dual(self.clone_real_part(), Link::EndOfLink)
}
}
impl<A> Display for Scalar<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Scalar::Number(n) => f.write_fmt(format_args!("{}", n)),
Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)),
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
Self::Vector(arg0) => Self::Vector(arg0.clone()),
}
}
}
pub enum Differentiable<A> {
enum DifferentiableHidden<A> {
Scalar(Scalar<A>),
Vector(Box<[Differentiable<A>]>),
Vector(Vec<DifferentiableHidden<A>>),
}
impl<A> Display for Differentiable<A>
impl<A> Display for DifferentiableHidden<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
Differentiable::Vector(v) => {
DifferentiableHidden::Scalar(s) => f.write_fmt(format_args!("{}", s)),
DifferentiableHidden::Vector(v) => {
f.write_char('[')?;
for v in v.iter() {
f.write_fmt(format_args!("{}", v))?;
@@ -271,26 +70,70 @@ where
}
}
impl<A> Differentiable<A> {
pub fn map<B, F>(&self, f: &F) -> Differentiable<B>
impl<A> DifferentiableHidden<A> {
fn map<B, F>(&self, f: &F) -> DifferentiableHidden<B>
where
F: Fn(Scalar<A>) -> Scalar<B>,
A: Clone,
{
match self {
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
Differentiable::Vector(slice) => {
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
DifferentiableHidden::Scalar(a) => DifferentiableHidden::Scalar(f(a.clone())),
DifferentiableHidden::Vector(slice) => {
DifferentiableHidden::Vector(slice.iter().map(|x| x.map(f)).collect())
}
}
}
/// Combines two identically shaped values element-wise with `f`.
///
/// Scalars are combined directly; vectors are combined position by position,
/// recursing into nested vectors.
///
/// # Panics
///
/// Panics with "Wrong shapes!" if one side is a scalar and the other a vector,
/// or if two vectors at the same nesting level have different lengths.
fn map2<B, C, F>(&self, other: &DifferentiableHidden<B>, f: &F) -> DifferentiableHidden<C>
where
    F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
    A: Clone,
    B: Clone,
{
    match (self, other) {
        (DifferentiableHidden::Scalar(a), DifferentiableHidden::Scalar(b)) => {
            DifferentiableHidden::Scalar(f(a, b))
        }
        (DifferentiableHidden::Vector(slice_a), DifferentiableHidden::Vector(slice_b)) => {
            // `zip` stops at the shorter input, so without this check a length
            // mismatch would silently drop the trailing elements instead of
            // being reported like every other shape mismatch.
            assert_eq!(slice_a.len(), slice_b.len(), "Wrong shapes!");
            DifferentiableHidden::Vector(
                slice_a
                    .iter()
                    .zip(slice_b.iter())
                    .map(|(a, b)| a.map2(b, f))
                    .collect(),
            )
        }
        _ => panic!("Wrong shapes!"),
    }
}
/// Builds a vector whose elements are the values of `input`, each cloned and
/// wrapped as a plain `Scalar::Number`.
fn of_slice(input: &[A]) -> DifferentiableHidden<A>
where
    A: Clone,
{
    let elements = input
        .iter()
        .cloned()
        .map(Scalar::Number)
        .map(DifferentiableHidden::Scalar)
        .collect();
    DifferentiableHidden::Vector(elements)
}
}
impl<A> Differentiable<A>
impl<A> DifferentiableHidden<A>
where
A: Clone + Eq + Hash + AddAssign + Mul<Output = A> + Exp + Div<Output = A> + Zero + One,
A: Clone
+ Eq
+ Hash
+ AddAssign
+ Mul<Output = A>
+ Exp
+ Div<Output = A>
+ Zero
+ One
+ Neg<Output = A>,
{
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
fn accumulate_gradients_vec(v: &[DifferentiableHidden<A>], acc: &mut HashMap<Scalar<A>, A>) {
for v in v.iter().rev() {
v.accumulate_gradients(acc);
}
@@ -298,15 +141,17 @@ where
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
match self {
Differentiable::Scalar(y) => {
DifferentiableHidden::Scalar(y) => {
let k = y.clone_link();
k.invoke(y, A::one(), acc);
}
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc),
DifferentiableHidden::Vector(y) => {
DifferentiableHidden::accumulate_gradients_vec(y, acc)
}
}
}
fn grad_once(self, wrt: Differentiable<A>) -> Differentiable<A> {
fn grad_once(self, wrt: &DifferentiableHidden<A>) -> DifferentiableHidden<A> {
let mut acc = HashMap::new();
self.accumulate_gradients(&mut acc);
@@ -315,34 +160,133 @@ where
Some(x) => Scalar::Number(x.clone()),
})
}
}
pub fn grad<F>(f: F, theta: Differentiable<A>) -> Differentiable<A>
/// A `DifferentiableHidden` value tagged with a compile-time `RANK`.
///
/// The rank is fixed by the constructors in this module: `of_scalar` yields
/// rank 0, `of_slice` yields rank 1, and `of_vector` yields `RANK + 1` from
/// elements of rank `RANK`.
#[derive(Clone)]
pub struct Differentiable<A, const RANK: usize> {
// The untyped-rank representation that all operations delegate to.
contents: DifferentiableHidden<A>,
}
impl<A, const RANK: usize> Display for Differentiable<A, RANK>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.contents, f)
}
}
/// Wraps a single scalar as a rank-0 `Differentiable`.
pub fn of_scalar<A>(s: Scalar<A>) -> Differentiable<A, 0> {
    let contents = DifferentiableHidden::Scalar(s);
    Differentiable { contents }
}
/// Unwraps a rank-0 `Differentiable` into the scalar it holds.
///
/// # Panics
///
/// Panics with "not a scalar" if the contents are a vector.
pub fn to_scalar<A>(s: Differentiable<A, 0>) -> Scalar<A> {
    match s.contents {
        DifferentiableHidden::Scalar(s) => s,
        // A vector was found where a scalar was required, so the message
        // names what the value is *not*.
        DifferentiableHidden::Vector(_) => panic!("not a scalar"),
    }
}
/// Builds a rank-1 `Differentiable` from a slice of plain values.
///
/// Delegates to `DifferentiableHidden::of_slice`, which clones each element.
pub fn of_slice<A>(input: &[A]) -> Differentiable<A, 1>
where
    A: Clone,
{
    let contents = DifferentiableHidden::of_slice(input);
    Differentiable { contents }
}
impl<A, const RANK: usize> Differentiable<A, RANK> {
pub fn of_vector(s: Vec<Differentiable<A, RANK>>) -> Differentiable<A, { RANK + 1 }> {
Differentiable {
contents: DifferentiableHidden::Vector(s.into_iter().map(|v| v.contents).collect()),
}
}
pub fn map<B, F>(s: Differentiable<A, RANK>, f: &F) -> Differentiable<B, RANK>
where
F: Fn(&Differentiable<A>) -> Differentiable<A>,
F: Fn(Scalar<A>) -> Scalar<B>,
A: Clone,
{
let wrt = theta.map(&Scalar::truncate_dual);
let after_f = f(&wrt);
Differentiable::grad_once(after_f, wrt)
Differentiable {
contents: DifferentiableHidden::map(&s.contents, f),
}
}
/// Combines `self` and `other` element-wise with `f`, preserving the rank.
///
/// Delegates to `DifferentiableHidden::map2` on the wrapped contents.
pub fn map2<B, C, F>(&self, other: &Differentiable<B, RANK>, f: &F) -> Differentiable<C, RANK>
where
    F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
    A: Clone,
    B: Clone,
{
    let combined = self.contents.map2(&other.contents, f);
    Differentiable { contents: combined }
}
/// Splits a value of rank `RANK + 1` into its rank-`RANK` elements.
///
/// # Panics
///
/// Panics with "not a vector" if the contents are a scalar.
pub fn to_vector(s: Differentiable<A, { RANK + 1 }>) -> Vec<Differentiable<A, RANK>> {
    match s.contents {
        // A scalar was found where a vector was required, so the message
        // names what the value is *not*.
        DifferentiableHidden::Scalar(_) => panic!("not a vector"),
        DifferentiableHidden::Vector(v) => v
            .into_iter()
            .map(|v| Differentiable { contents: v })
            .collect(),
    }
}
/// Evaluates `f` at `theta` and returns the gradient with respect to `theta`.
///
/// Every scalar in `theta` is first passed through `Scalar::truncate_dual`,
/// so any history it carried is discarded before `f` runs. The result of `f`
/// is then handed to `DifferentiableHidden::grad_once` together with those
/// truncated parameters.
pub fn grad<F>(f: F, theta: Differentiable<A, RANK>) -> Differentiable<A, RANK>
where
    F: Fn(Differentiable<A, RANK>) -> Differentiable<A, RANK>,
    A: Clone
        + Hash
        + AddAssign
        + Mul<Output = A>
        + Exp
        + Div<Output = A>
        + Zero
        + One
        + Neg<Output = A>
        + Eq,
{
    let parameters = theta.contents.map(&Scalar::truncate_dual);
    // `f` takes ownership of its argument, so it gets a copy; the original is
    // still needed afterwards to look the gradients up.
    let input = Differentiable {
        contents: parameters.clone(),
    };
    let output = f(input);
    let gradient = DifferentiableHidden::grad_once(output.contents, &parameters);
    Differentiable { contents: gradient }
}
}
#[cfg(test)]
mod tests {
use ordered_float::NotNan;
use super::*;
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
fn extract_scalar<'a, A>(d: &'a DifferentiableHidden<A>) -> &'a A {
match d {
Differentiable::Scalar(a) => &(a.real_part()),
Differentiable::Vector(_) => panic!("not a scalar"),
DifferentiableHidden::Scalar(a) => &(a.real_part()),
DifferentiableHidden::Vector(_) => panic!("not a scalar"),
}
}
#[test]
fn test_map() {
let v = Differentiable::Vector(
let v = DifferentiableHidden::Vector(
vec![
Differentiable::Scalar(Scalar::Number(NotNan::new(3.0).expect("3 is not NaN"))),
Differentiable::Scalar(Scalar::Number(NotNan::new(4.0).expect("4 is not NaN"))),
DifferentiableHidden::Scalar(Scalar::Number(
NotNan::new(3.0).expect("3 is not NaN"),
)),
DifferentiableHidden::Scalar(Scalar::Number(
NotNan::new(4.0).expect("4 is not NaN"),
)),
]
.into(),
);
@@ -352,9 +296,8 @@ mod tests {
});
let v = match mapped {
Differentiable::Scalar(_) => panic!("Not a scalar"),
Differentiable::Vector(v) => v
.as_ref()
DifferentiableHidden::Scalar(_) => panic!("Not a scalar"),
DifferentiableHidden::Vector(v) => v
.iter()
.map(|d| extract_scalar(d).clone())
.collect::<Vec<_>>(),

View File

@@ -1,3 +1,8 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
pub mod auto_diff;
pub mod expr_syntax_tree;
pub mod scalar;
pub mod tensor;
pub mod traits;

View File

@@ -0,0 +1,251 @@
use crate::traits::{Exp, One, Zero};
use core::hash::Hash;
use std::{
collections::{hash_map::Entry, HashMap},
fmt::Display,
iter::Sum,
ops::{Add, AddAssign, Div, Mul, Neg, Sub},
};
/// The operation that produced a [`Scalar`], together with its boxed operand(s).
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum LinkData<A> {
/// Sum of the two operands; displayed as `(left + right)`.
Addition(Box<Scalar<A>>, Box<Scalar<A>>),
/// Negation of the operand; displayed as `(-arg)`.
Neg(Box<Scalar<A>>),
/// Product of the two operands; displayed as `(left * right)`.
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
/// Exponential of the operand; displayed as `exp(arg)`.
Exponent(Box<Scalar<A>>),
/// Logarithm of the operand; displayed as `log(arg)`.
Log(Box<Scalar<A>>),
}
/// How a [`Scalar`] was obtained: either there is nothing further to trace, or
/// it is the result of a recorded operation.
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Link<A> {
/// No recorded operation; displayed as `<end>`. `invoke` accumulates the
/// incoming gradient into the map when it reaches this variant.
EndOfLink,
/// The recorded operation and its operands.
Link(LinkData<A>),
}
impl<A> Display for Link<A>
where
    A: Display,
{
    /// Renders the recorded operation as an expression, or `<end>` when there
    /// is no operation to show.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data = match self {
            Link::EndOfLink => return f.write_str("<end>"),
            Link::Link(data) => data,
        };
        match data {
            LinkData::Addition(lhs, rhs) => write!(f, "({} + {})", lhs.as_ref(), rhs.as_ref()),
            LinkData::Neg(operand) => write!(f, "(-{})", operand.as_ref()),
            LinkData::Mul(lhs, rhs) => write!(f, "({} * {})", lhs.as_ref(), rhs.as_ref()),
            LinkData::Exponent(operand) => write!(f, "exp({})", operand.as_ref()),
            LinkData::Log(operand) => write!(f, "log({})", operand.as_ref()),
        }
    }
}
impl<A> Link<A> {
pub fn invoke(self, d: &Scalar<A>, z: A, acc: &mut HashMap<Scalar<A>, A>)
where
A: Eq
+ Hash
+ AddAssign
+ Clone
+ Exp
+ Mul<Output = A>
+ Div<Output = A>
+ Neg<Output = A>
+ Zero
+ One,
{
match self {
Link::EndOfLink => match acc.entry(d.clone()) {
Entry::Occupied(mut o) => {
let entry = o.get_mut();
*entry += z;
}
Entry::Vacant(v) => {
v.insert(z);
}
},
Link::Link(data) => {
match data {
LinkData::Addition(left, right) => {
// The `z` here reflects the fact that dx/dx = 1, so it's 1 * z.
left.as_ref().clone_link().invoke(&left, z.clone(), acc);
right.as_ref().clone_link().invoke(&right, z, acc);
}
LinkData::Exponent(arg) => {
// d/dx (e^x) = exp x, so exp z * z.
arg.as_ref().clone_link().invoke(
&arg,
z * arg.clone_real_part().exp(),
acc,
);
}
LinkData::Mul(left, right) => {
// d/dx(f g) = f dg/dx + g df/dx
left.as_ref().clone_link().invoke(
&left,
right.clone_real_part() * z.clone(),
acc,
);
right
.as_ref()
.clone_link()
.invoke(&right, left.clone_real_part() * z, acc);
}
LinkData::Log(arg) => {
// d/dx(log y) = 1/y dy/dx
arg.as_ref().clone_link().invoke(
&arg,
A::one() / arg.clone_real_part() * z,
acc,
);
}
LinkData::Neg(arg) => {
// d/dx(-y) = - dy/dx
arg.as_ref().clone_link().invoke(&arg, -z, acc);
}
}
}
}
}
}
/// A numeric value that may remember the operation which produced it.
#[derive(Clone, Hash, PartialEq, Eq)]
pub enum Scalar<A> {
/// A plain value with no recorded history.
Number(A),
// The value, and the link.
Dual(A, Link<A>),
}
impl<A> Zero for Scalar<A>
where
    A: Zero,
{
    /// The zero of `A`, wrapped as a plain number with no recorded history.
    fn zero() -> Self {
        let value = A::zero();
        Self::Number(value)
    }
}
impl<A> Add for Scalar<A>
where
    A: Add<Output = A> + Clone,
{
    type Output = Scalar<A>;

    /// Adds the two values and records both operands so the sum can be
    /// differentiated later.
    fn add(self, rhs: Self) -> Self::Output {
        let value = self.clone_real_part() + rhs.clone_real_part();
        let history = LinkData::Addition(Box::new(self), Box::new(rhs));
        Scalar::Dual(value, Link::Link(history))
    }
}
impl<A> Neg for Scalar<A>
where
    A: Neg<Output = A> + Clone,
{
    type Output = Scalar<A>;

    /// Negates the value and records the operand so the negation can be
    /// differentiated later.
    fn neg(self) -> Self::Output {
        let value = -self.clone_real_part();
        let history = LinkData::Neg(Box::new(self));
        Scalar::Dual(value, Link::Link(history))
    }
}
impl<A> Sub for Scalar<A>
where
    A: Add<Output = A> + Neg<Output = A> + Clone,
{
    type Output = Scalar<A>;

    /// Subtraction expressed as addition of the negated right-hand side, so
    /// it is recorded as a `Neg` followed by an `Addition`.
    fn sub(self, rhs: Self) -> Self::Output {
        let negated = rhs.neg();
        self.add(negated)
    }
}
impl<A> Mul for Scalar<A>
where
    A: Mul<Output = A> + Clone,
{
    type Output = Scalar<A>;

    /// Multiplies the two values and records both operands so the product can
    /// be differentiated later.
    fn mul(self, rhs: Self) -> Self::Output {
        let value = self.clone_real_part() * rhs.clone_real_part();
        let history = LinkData::Mul(Box::new(self), Box::new(rhs));
        Scalar::Dual(value, Link::Link(history))
    }
}
impl<A> Sum for Scalar<A>
where
    A: Zero + Add<Output = A> + Clone,
{
    /// Adds the items left to right, starting from zero.
    ///
    /// An empty iterator yields zero itself.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Self::zero(), |running, term| running + term)
    }
}
impl<A> Scalar<A> {
    /// Borrows the numeric value, ignoring any recorded link.
    pub fn real_part(&self) -> &A {
        match self {
            Scalar::Number(a) | Scalar::Dual(a, _) => a,
        }
    }

    /// Returns a copy of the numeric value, ignoring any recorded link.
    pub fn clone_real_part(&self) -> A
    where
        A: Clone,
    {
        self.real_part().clone()
    }

    /// Consumes the scalar and returns its link; a plain number yields
    /// `Link::EndOfLink`.
    pub fn link(self) -> Link<A> {
        match self {
            Scalar::Dual(_, link) => link,
            Scalar::Number(_) => Link::EndOfLink,
        }
    }

    /// Returns a copy of the scalar's link; a plain number yields
    /// `Link::EndOfLink`.
    pub fn clone_link(&self) -> Link<A>
    where
        A: Clone,
    {
        match self {
            Scalar::Dual(_, data) => data.clone(),
            Scalar::Number(_) => Link::EndOfLink,
        }
    }

    /// Discards any recorded history, returning a `Dual` that holds the same
    /// value with `Link::EndOfLink`.
    pub fn truncate_dual(self) -> Scalar<A>
    where
        A: Clone,
    {
        // `self` is owned, so the value is moved out rather than cloned. The
        // `Clone` bound is kept only so existing callers see the same signature.
        let value = match self {
            Scalar::Number(a) | Scalar::Dual(a, _) => a,
        };
        Scalar::Dual(value, Link::EndOfLink)
    }
}
impl<A> Display for Scalar<A>
where
    A: Display,
{
    /// Prints a plain number as its value alone, and a dual as the value
    /// followed by `, link: ` and the recorded link.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Scalar::Number(value) => write!(f, "{}", value),
            Scalar::Dual(value, history) => write!(f, "{}, link: {}", value, history),
        }
    }
}

View File

@@ -0,0 +1,43 @@
use ordered_float::NotNan;
/// Types that support an exponential function.
pub trait Exp {
/// Consumes the value and returns its exponential.
fn exp(self) -> Self;
}
impl Exp for NotNan<f64> {
    /// Applies `f64::exp` to the inner value and re-wraps the result.
    ///
    /// # Panics
    ///
    /// Panics if the result cannot be wrapped in a `NotNan`.
    fn exp(self) -> Self {
        let raw = self.into_inner().exp();
        NotNan::new(raw).expect("expected a non-NaN")
    }
}
/// Types that have a distinguished zero value.
pub trait Zero {
/// Returns the zero value of the type.
fn zero() -> Self;
}
/// Types that have a distinguished one value.
pub trait One {
/// Returns the one value of the type.
fn one() -> Self;
}
impl Zero for f64 {
    /// Zero for `f64` is `0.0`.
    fn zero() -> Self {
        0.0_f64
    }
}
impl One for f64 {
    /// One for `f64` is `1.0`.
    fn one() -> Self {
        1.0_f64
    }
}
impl Zero for NotNan<f64> {
    /// Zero for `NotNan<f64>` is `0.0` wrapped in `NotNan`.
    fn zero() -> Self {
        let raw: f64 = 0.0;
        NotNan::new(raw).unwrap()
    }
}
impl One for NotNan<f64> {
    /// One for `NotNan<f64>` is `1.0` wrapped in `NotNan`.
    fn one() -> Self {
        let raw: f64 = 1.0;
        NotNan::new(raw).unwrap()
    }
}