Make Scalar numlike (#15)

This commit is contained in:
Patrick Stevens
2023-04-30 13:09:16 +01:00
committed by GitHub
parent ae6430aa85
commit 64d98757f4
8 changed files with 218 additions and 123 deletions

View File

@@ -478,6 +478,7 @@ mod tests {
use ordered_float::NotNan;
use crate::loss::{l2_loss_2, predict_line_2_unranked};
use crate::not_nan::to_not_nan_1;
use super::*;
@@ -539,4 +540,53 @@ mod tests {
.map(|x| f64::from(*x.real_part()));
assert_eq!(grad_vec, [-63.0, -21.0]);
}
#[test]
fn grad_example() {
    // Differentiate f(x) = x * x at x = 27; the expected derivative is 2 * 27 = 54.
    let point = NotNan::new(27.0).expect("not nan");
    let input_vec = [Differentiable::of_scalar(Scalar::make(point))];

    let derivative = grad(
        |params| {
            let x = params[0].borrow_scalar();
            RankedDifferentiable::of_scalar(x.clone() * x.clone())
        },
        &input_vec,
    );

    // Pull the raw float out of each gradient component before comparing.
    let observed: Vec<_> = derivative
        .into_iter()
        .map(|component| component.into_scalar().real_part().into_inner())
        .collect();
    assert_eq!(observed, [54.0]);
}
#[test]
fn loss_gradient() {
    // Both parameters of the line start at zero.
    let zero = Scalar::<NotNan<f64>>::zero();
    let first = RankedDifferentiable::of_scalar(zero.clone()).to_unranked();
    let second = RankedDifferentiable::of_scalar(zero).to_unranked();
    let input_vec = [first, second];

    let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
    let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);

    let gradient = grad(
        |theta| {
            let loss = l2_loss_2(
                predict_line_2_unranked,
                RankedDifferentiable::of_slice(&xs),
                RankedDifferentiable::of_slice(&ys),
                theta,
            );
            // The loss scalar is wrapped in a one-element vector before being returned.
            RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(loss)])
        },
        &input_vec,
    );

    let observed: Vec<_> = gradient
        .into_iter()
        .map(|component| *(component.into_scalar().real_part()))
        .collect();
    assert_eq!(observed, [-63.0, -21.0]);
}
}

View File

@@ -6,6 +6,7 @@ pub mod auto_diff;
pub mod const_teq;
pub mod expr_syntax_tree;
pub mod loss;
pub mod not_nan;
pub mod scalar;
pub mod tensor;
pub mod traits;

View File

@@ -232,8 +232,7 @@ type ParameterPredictor<T, const INPUT_DIM: usize, const THETA: usize> =
&[Differentiable<T>; THETA],
) -> RankedDifferentiable<T, 1>;
pub const fn plane_predictor<T>(
) -> Predictor<ParameterPredictor<T, 2, 2>, [Differentiable<T>; 2], [Differentiable<T>; 2]>
pub const fn plane_predictor<T>() -> Predictor<ParameterPredictor<T, 2, 2>, Scalar<T>, Scalar<T>>
where
T: NumLike + Default,
{
@@ -245,9 +244,9 @@ where
}
pub const fn line_unranked_predictor<T>(
) -> Predictor<ParameterPredictor<T, 1, 2>, [Differentiable<T>; 2], [Differentiable<T>; 2]>
) -> Predictor<ParameterPredictor<T, 1, 2>, Scalar<T>, Scalar<T>>
where
T: NumLike + Default,
T: NumLike + Default + Copy,
{
Predictor {
predict: predict_line_2_unranked,
@@ -257,7 +256,7 @@ where
}
pub const fn quadratic_unranked_predictor<T>(
) -> Predictor<ParameterPredictor<T, 1, 3>, [Differentiable<T>; 3], [Differentiable<T>; 3]>
) -> Predictor<ParameterPredictor<T, 1, 3>, Scalar<T>, Scalar<T>>
where
T: NumLike + Default,
{
@@ -267,3 +266,28 @@ where
deflate: |x| x,
}
}
#[cfg(test)]
mod test_loss {
    use crate::{
        auto_diff::RankedDifferentiable,
        loss::{l2_loss_2, predict_line_2},
        scalar::Scalar,
        traits::Zero,
    };

    /// Checks the L2 loss of `predict_line_2` on a fixed data set with both parameters zero.
    #[test]
    fn loss_example() {
        let xs = [2.0, 1.0, 4.0, 3.0];
        let ys = [1.8, 1.2, 4.2, 3.3];

        let inputs = RankedDifferentiable::of_slice(&xs);
        let targets = RankedDifferentiable::of_slice(&ys);

        let loss = l2_loss_2(
            predict_line_2,
            inputs,
            targets,
            &[
                RankedDifferentiable::of_scalar(Scalar::zero()),
                RankedDifferentiable::of_scalar(Scalar::zero()),
            ],
        );

        assert_eq!(*loss.real_part(), 33.21);
    }
}

View File

@@ -0,0 +1,15 @@
use ordered_float::NotNan;
/// Wraps every element of `xs` in a [`NotNan`], keeping the array's order and length.
///
/// # Panics
///
/// Panics with the message `"not nan"` if `NotNan::new` rejects any element.
pub fn to_not_nan_1<T, const N: usize>(xs: [T; N]) -> [NotNan<T>; N]
where
    T: ordered_float::Float,
{
    let wrap = |value: T| NotNan::new(value).expect("not nan");
    xs.map(wrap)
}
pub fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
where
T: ordered_float::Float,
{
xs.map(to_not_nan_1)
}

View File

@@ -14,6 +14,7 @@ pub enum LinkData<A> {
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
Exponent(Box<Scalar<A>>),
Log(Box<Scalar<A>>),
Div(Box<Scalar<A>>, Box<Scalar<A>>),
}
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
@@ -41,6 +42,9 @@ where
f.write_fmt(format_args!("exp({})", arg.as_ref()))
}
Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())),
Link::Link(LinkData::Div(left, right)) => {
f.write_fmt(format_args!("({} / {})", left.as_ref(), right.as_ref()))
}
}
}
}
@@ -96,6 +100,21 @@ impl<A> Link<A> {
.clone_link()
.invoke(&right, left.clone_real_part() * z, acc);
}
LinkData::Div(left, right) => {
// d/dx(f / g) = f d(1/g)/dx + (df/dx) / g
// = -f (dg/dx)/g^2 + (df/dx) / g
left.as_ref().clone_link().invoke(
&left,
z.clone() / right.clone_real_part(),
acc,
);
right.as_ref().clone_link().invoke(
&right,
-left.clone_real_part() * z
/ (right.clone_real_part() * right.clone_real_part()),
acc,
)
}
LinkData::Log(arg) => {
// d/dx(log y) = 1/y dy/dx
arg.as_ref().clone_link().invoke(
@@ -144,6 +163,15 @@ where
}
}
/// In-place addition for [`Scalar`], defined in terms of the by-value `Add` implementation.
impl<A> AddAssign for Scalar<A>
where
    A: Add<Output = A> + Clone,
{
    /// Replaces `self` with `self + rhs`.
    fn add_assign(&mut self, rhs: Self) {
        // `Add` consumes both operands, but `self` is only borrowed here. Park a clone of
        // `rhs` in `self` so the current value can be moved out, rather than cloning `self`:
        // in an accumulation loop `self` is the operand that keeps growing, and cloning a
        // `Scalar` copies everything its link holds.
        let lhs = std::mem::replace(self, rhs.clone());
        *self = lhs + rhs;
    }
}
impl<A> Neg for Scalar<A>
where
A: Neg<Output = A> + Clone,
@@ -190,12 +218,47 @@ where
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
let mut answer = Zero::zero();
for i in iter {
answer = answer + i;
answer += i;
}
answer
}
}
/// Exponentiation of a [`Scalar`], recording the operand in an `Exponent` link.
impl<A> Exp for Scalar<A>
where
    A: Exp + Clone,
{
    fn exp(self) -> Self {
        // Compute the value first: `self` is moved into the link afterwards.
        let value = self.clone_real_part().exp();
        let link = Link::Link(LinkData::Exponent(Box::new(self)));
        Self::Dual(value, link)
    }
}
/// Division of two [`Scalar`]s, recording both operands in a `Div` link.
impl<A> Div for Scalar<A>
where
    A: Div<Output = A> + Clone,
{
    type Output = Scalar<A>;

    fn div(self, rhs: Self) -> Self::Output {
        // Compute the quotient first: both operands are moved into the link afterwards.
        let quotient = self.clone_real_part() / rhs.clone_real_part();
        let link = Link::Link(LinkData::Div(Box::new(self), Box::new(rhs)));
        Self::Dual(quotient, link)
    }
}
impl<A> Default for Scalar<A>
where
A: Default,
{
fn default() -> Self {
Scalar::Number(A::default(), None)
}
}
impl<A> Scalar<A> {
pub fn real_part(&self) -> &A {
match self {
@@ -255,3 +318,39 @@ where
}
}
}
#[cfg(test)]
mod test_loss {
    use crate::scalar::Scalar;
    use ordered_float::NotNan;
    use std::collections::HashMap;

    /// Checks both the value of 3 / 5 and the derivative accumulated for each operand.
    #[test]
    fn div_gradient() {
        let numerator = Scalar::make(NotNan::new(3.0).expect("not nan"));
        let denominator = Scalar::make(NotNan::new(5.0).expect("not nan"));
        let quotient = numerator / denominator;
        assert_eq!(quotient.clone_real_part().into_inner(), 3.0 / 5.0);

        let seed = NotNan::new(1.0).expect("not nan");
        let mut gradients = HashMap::new();
        quotient.clone_link().invoke(&quotient, seed, &mut gradients);

        // d/dx (x / 5) is the constant 1/5.
        // d/dx (3 / x) is -3 / x^2, which is -3/25 at x = 5.
        assert_eq!(gradients.len(), 2);
        for (operand, derivative) in gradients {
            let operand = operand.real_part().into_inner();
            let derivative = derivative.into_inner();
            // The two entries are told apart by their real parts (3 and 5).
            if operand < 4.0 {
                // Numerator entry.
                assert_eq!(operand, 3.0);
                assert_eq!(derivative, 1.0 / 5.0);
            } else {
                // Denominator entry.
                assert_eq!(operand, 5.0);
                assert_eq!(derivative, -3.0 / 25.0);
            }
        }
    }
}

View File

@@ -1,3 +1,4 @@
use crate::scalar::Scalar;
use ordered_float::NotNan;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Div, Mul, Neg};
@@ -54,11 +55,13 @@ pub trait NumLike:
+ Mul<Output = Self>
+ Div<Output = Self>
+ Sum
+ Default
+ Clone
+ Copy
+ Sized
+ PartialEq
+ Eq
{
}
// Marker implementation: the body is empty, so `NumLike` is satisfied purely by its bounds.
impl NumLike for NotNan<f64> {}
// A `Scalar<A>` is `NumLike` whenever its underlying number type `A` is.
impl<A: NumLike> NumLike for Scalar<A> {}