Make Scalar numlike (#15)

Patrick Stevens
2023-04-30 13:09:16 +01:00
committed by GitHub
parent ae6430aa85
commit 64d98757f4
8 changed files with 218 additions and 123 deletions

View File

@@ -478,6 +478,7 @@ mod tests {
use ordered_float::NotNan;
use crate::loss::{l2_loss_2, predict_line_2_unranked};
use crate::not_nan::to_not_nan_1;
use super::*;
@@ -539,4 +540,53 @@ mod tests {
.map(|x| f64::from(*x.real_part()));
assert_eq!(grad_vec, [-63.0, -21.0]);
}
#[test]
fn grad_example() {
let input_vec = [Differentiable::of_scalar(Scalar::make(
NotNan::new(27.0).expect("not nan"),
))];
let grad: Vec<_> = grad(
|x| {
RankedDifferentiable::of_scalar(
x[0].borrow_scalar().clone() * x[0].borrow_scalar().clone(),
)
},
&input_vec,
)
.into_iter()
.map(|x| x.into_scalar().real_part().into_inner())
.collect();
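// d/dx (x * x) = 2x, so the gradient at x = 27 is 54.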
assert_eq!(grad, [54.0]);
}
#[test]
fn loss_gradient() {
let zero = Scalar::<NotNan<f64>>::zero();
let input_vec = [
RankedDifferentiable::of_scalar(zero.clone()).to_unranked(),
RankedDifferentiable::of_scalar(zero).to_unranked(),
];
let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);
let grad = grad(
|x| {
RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(l2_loss_2(
predict_line_2_unranked,
RankedDifferentiable::of_slice(&xs),
RankedDifferentiable::of_slice(&ys),
x,
))])
},
&input_vec,
);
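// At theta = (0, 0) every prediction is 0, so dL/dtheta_0 = -2 * sum(x_i * y_i) = -63
// and dL/dtheta_1 = -2 * sum(y_i) = -21.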
assert_eq!(
grad.into_iter()
.map(|x| *(x.into_scalar().real_part()))
.collect::<Vec<_>>(),
[-63.0, -21.0]
);
}
}

View File

@@ -6,6 +6,7 @@ pub mod auto_diff;
pub mod const_teq;
pub mod expr_syntax_tree;
pub mod loss;
pub mod not_nan;
pub mod scalar;
pub mod tensor;
pub mod traits;

View File

@@ -232,8 +232,7 @@ type ParameterPredictor<T, const INPUT_DIM: usize, const THETA: usize> =
&[Differentiable<T>; THETA],
) -> RankedDifferentiable<T, 1>;
-pub const fn plane_predictor<T>(
-) -> Predictor<ParameterPredictor<T, 2, 2>, [Differentiable<T>; 2], [Differentiable<T>; 2]>
+pub const fn plane_predictor<T>() -> Predictor<ParameterPredictor<T, 2, 2>, Scalar<T>, Scalar<T>>
where
T: NumLike + Default,
{
@@ -245,9 +244,9 @@ where
}
pub const fn line_unranked_predictor<T>(
-) -> Predictor<ParameterPredictor<T, 1, 2>, [Differentiable<T>; 2], [Differentiable<T>; 2]>
+) -> Predictor<ParameterPredictor<T, 1, 2>, Scalar<T>, Scalar<T>>
where
-T: NumLike + Default,
+T: NumLike + Default + Copy,
{
Predictor {
predict: predict_line_2_unranked,
@@ -257,7 +256,7 @@ where
}
pub const fn quadratic_unranked_predictor<T>(
-) -> Predictor<ParameterPredictor<T, 1, 3>, [Differentiable<T>; 3], [Differentiable<T>; 3]>
+) -> Predictor<ParameterPredictor<T, 1, 3>, Scalar<T>, Scalar<T>>
where
T: NumLike + Default,
{
@@ -267,3 +266,28 @@ where
deflate: |x| x,
}
}
#[cfg(test)]
mod test_loss {
use crate::auto_diff::RankedDifferentiable;
use crate::loss::{l2_loss_2, predict_line_2};
use crate::scalar::Scalar;
use crate::traits::Zero;
#[test]
fn loss_example() {
let xs = [2.0, 1.0, 4.0, 3.0];
let ys = [1.8, 1.2, 4.2, 3.3];
let loss = l2_loss_2(
predict_line_2,
RankedDifferentiable::of_slice(&xs),
RankedDifferentiable::of_slice(&ys),
&[
RankedDifferentiable::of_scalar(Scalar::zero()),
RankedDifferentiable::of_scalar(Scalar::zero()),
],
);
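// With theta = (0, 0) every prediction is 0, so the loss is
// sum(y_i^2) = 3.24 + 1.44 + 17.64 + 10.89 = 33.21.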
assert_eq!(*loss.real_part(), 33.21);
}
}

View File

@@ -0,0 +1,15 @@
use ordered_float::NotNan;
pub fn to_not_nan_1<T, const N: usize>(xs: [T; N]) -> [NotNan<T>; N]
where
T: ordered_float::Float,
{
xs.map(|x| NotNan::new(x).expect("not nan"))
}
pub fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
where
T: ordered_float::Float,
{
xs.map(to_not_nan_1)
}
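A minimal usage sketch for these helpers (the `demo` function is hypothetical; the helpers panic if any input is NaN, per the `expect` above):

use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};

fn demo() {
    // Lift plain float arrays into NotNan-wrapped arrays of the same shape.
    let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
    let grid = to_not_nan_2([[1.0, 2.0], [3.0, 4.0]]);
    assert_eq!(xs[0].into_inner(), 2.0);
    assert_eq!(grid[1][0].into_inner(), 3.0);
}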

View File

@@ -14,6 +14,7 @@ pub enum LinkData<A> {
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
Exponent(Box<Scalar<A>>),
Log(Box<Scalar<A>>),
Div(Box<Scalar<A>>, Box<Scalar<A>>),
}
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
@@ -41,6 +42,9 @@ where
f.write_fmt(format_args!("exp({})", arg.as_ref()))
}
Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())),
Link::Link(LinkData::Div(left, right)) => {
f.write_fmt(format_args!("({} / {})", left.as_ref(), right.as_ref()))
}
}
}
}
@@ -96,6 +100,21 @@ impl<A> Link<A> {
.clone_link()
.invoke(&right, left.clone_real_part() * z, acc);
}
LinkData::Div(left, right) => {
// d/dx(f / g) = f d(1/g)/dx + (df/dx) / g
// = -f (dg/dx)/g^2 + (df/dx) / g
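// With upstream sensitivity z arriving at f / g, the numerator f
// receives z * d(f/g)/df = z / g, and the denominator g receives
// z * d(f/g)/dg = -f * z / g^2; these are the two quantities passed
// to `invoke` below.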
left.as_ref().clone_link().invoke(
&left,
z.clone() / right.clone_real_part(),
acc,
);
right.as_ref().clone_link().invoke(
&right,
-left.clone_real_part() * z
/ (right.clone_real_part() * right.clone_real_part()),
acc,
)
}
LinkData::Log(arg) => {
// d/dx(log y) = 1/y dy/dx
arg.as_ref().clone_link().invoke(
@@ -144,6 +163,15 @@ where
}
}
impl<A> AddAssign for Scalar<A>
where
A: Add<Output = A> + Clone,
{
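// Delegates to the Add impl, so the addition is still recorded in the
// derivative link chain.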
fn add_assign(&mut self, rhs: Self) {
*self = self.clone() + rhs
}
}
impl<A> Neg for Scalar<A>
where
A: Neg<Output = A> + Clone,
@@ -190,12 +218,47 @@ where
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
let mut answer = Zero::zero();
for i in iter {
-answer = answer + i;
+answer += i;
}
answer
}
}
impl<A> Exp for Scalar<A>
where
A: Exp + Clone,
{
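// Forward value is exp(x); the Exponent link records x so the chain
// rule can replay it during backpropagation.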
fn exp(self) -> Self {
Self::Dual(
self.clone_real_part().exp(),
Link::Link(LinkData::Exponent(Box::new(self))),
)
}
}
impl<A> Div for Scalar<A>
where
A: Div<Output = A> + Clone,
{
type Output = Scalar<A>;
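// Forward value is f / g; the Div link stores both operands so
// `invoke` can apply the quotient rule later.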
fn div(self, rhs: Self) -> Self::Output {
Self::Dual(
self.clone_real_part() / rhs.clone_real_part(),
Link::Link(LinkData::Div(Box::new(self), Box::new(rhs))),
)
}
}
impl<A> Default for Scalar<A>
where
A: Default,
{
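// A plain number with no link; required now that NumLike includes Default.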
fn default() -> Self {
Scalar::Number(A::default(), None)
}
}
impl<A> Scalar<A> {
pub fn real_part(&self) -> &A {
match self {
@@ -255,3 +318,39 @@ where
}
}
}
#[cfg(test)]
mod test_loss {
use crate::scalar::Scalar;
use ordered_float::NotNan;
use std::collections::HashMap;
#[test]
fn div_gradient() {
let left = Scalar::make(NotNan::new(3.0).expect("not nan"));
let right = Scalar::make(NotNan::new(5.0).expect("not nan"));
let divided = left / right;
assert_eq!(divided.clone_real_part().into_inner(), 3.0 / 5.0);
let mut acc = HashMap::new();
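// `invoke` walks the links backwards from `divided`, accumulating into
// `acc` the derivative of `divided` with respect to each input scalar.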
divided
.clone_link()
.invoke(&divided, NotNan::new(1.0).expect("not nan"), &mut acc);
// Derivative of x/5 with respect to x is the constant 1/5
// Derivative of 3/x with respect to x is -3/x^2, so at the value 5 is -3/25
assert_eq!(acc.len(), 2);
for (key, value) in acc {
let key = key.real_part().into_inner();
let value = value.into_inner();
if key < 4.0 {
// This is the numerator.
assert_eq!(key, 3.0);
assert_eq!(value, 1.0 / 5.0);
} else {
// This is the denominator.
assert_eq!(key, 5.0);
assert_eq!(value, -3.0 / 25.0);
}
}
}
}

View File

@@ -1,3 +1,4 @@
use crate::scalar::Scalar;
use ordered_float::NotNan;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Div, Mul, Neg};
@@ -54,11 +55,13 @@ pub trait NumLike:
+ Mul<Output = Self>
+ Div<Output = Self>
+ Sum
++ Default
++ Clone
-+ Copy
+ Sized
+ PartialEq
+ Eq
{
}
impl NumLike for NotNan<f64> {}
impl<A> NumLike for Scalar<A> where A: NumLike {}
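This blanket impl is the commit's headline change: a Scalar over any NumLike type is itself NumLike, so generic numeric code need not care whether it is computing with plain numbers or with dual numbers that record their own derivative links (and, by the same impl, Scalar<Scalar<A>> is NumLike as well, which permits nesting). A minimal sketch of the idea; `eval_poly` and `demo` are hypothetical helpers, not part of the crate:

use little_learner::scalar::Scalar;
use little_learner::traits::NumLike;
use ordered_float::NotNan;

// Written against NumLike only, so it accepts both representations.
fn eval_poly<T: NumLike>(x: T) -> T {
    x.clone() * x.clone() + x
}

fn demo() {
    // Plain number: just arithmetic; 3^2 + 3 = 12.
    let plain = eval_poly(NotNan::new(3.0).expect("not nan"));
    assert_eq!(plain.into_inner(), 12.0);
    // Dual number: the same code also builds the link graph for autodiff.
    let dual = eval_poly(Scalar::make(NotNan::new(3.0).expect("not nan")));
    assert_eq!(dual.clone_real_part().into_inner(), 12.0);
}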

View File

@@ -11,6 +11,7 @@ use little_learner::auto_diff::{grad, Differentiable, RankedDifferentiable};
use crate::sample::sample2;
use little_learner::loss::{l2_loss_2, plane_predictor, Predictor};
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
use little_learner::scalar::Scalar;
use little_learner::traits::{NumLike, Zero};
use ordered_float::NotNan;
@@ -47,7 +48,7 @@ where
let delta = &delta[i];
i += 1;
// For speed, you might want to truncate_dual this.
-let learning_rate = Scalar::make(learning_rate);
+let learning_rate = Scalar::make(learning_rate.clone());
Differentiable::map2(
&theta,
&delta.map(&mut |s| s * learning_rate.clone()),
@@ -56,37 +57,27 @@ where
})
}
-fn gradient_descent<
-'a,
-T,
-R: Rng,
-Point,
-F,
-G,
-const IN_SIZE: usize,
-const PARAM_NUM: usize,
-const INFLATED_NUM: usize,
->(
+fn gradient_descent<'a, T, R: Rng, Point, F, G, const IN_SIZE: usize, const PARAM_NUM: usize>(
mut hyper: GradientDescentHyper<T, R>,
xs: &'a [Point],
to_ranked_differentiable: G,
ys: &[T],
zero_params: [Differentiable<T>; PARAM_NUM],
-predictor: Predictor<F, [Differentiable<T>; INFLATED_NUM], [Differentiable<T>; PARAM_NUM]>,
+mut predictor: Predictor<F, Scalar<T>, Scalar<T>>,
) -> [Differentiable<T>; PARAM_NUM]
where
-T: NumLike + Eq + Hash,
+T: NumLike + Hash + Copy + Default,
Point: 'a + Copy,
F: Fn(
RankedDifferentiable<T, IN_SIZE>,
-&[Differentiable<T>; INFLATED_NUM],
+&[Differentiable<T>; PARAM_NUM],
) -> RankedDifferentiable<T, 1>,
G: for<'b> Fn(&'b [Point]) -> RankedDifferentiable<T, IN_SIZE>,
{
let iterations = hyper.iterations;
iterate(
|theta| {
-let out = gradient_descent_step::<T, _, 1, INFLATED_NUM>(
+let out = gradient_descent_step::<T, _, 1, PARAM_NUM>(
&mut |x| match hyper.sampling.as_mut() {
None => RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(
l2_loss_2(
@@ -108,30 +99,16 @@ where
)])
}
},
-(predictor.inflate)(theta),
+theta.map(|x| x.map(&mut predictor.inflate)),
hyper.learning_rate,
);
-(predictor.deflate)(out)
+out.map(|x| x.map(&mut predictor.deflate))
},
zero_params,
iterations,
)
}
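The inflate/deflate plumbing changes shape here: previously they rewrote the whole parameter array (length INFLATED_NUM versus PARAM_NUM), whereas now each is a per-scalar Scalar<T> -> Scalar<T> transform applied to every leaf. A schematic comparison (the `map` methods are assumed, per the call sites above, to apply a function to each element or scalar leaf):

// Before: one call rewrites the whole array, possibly changing its length.
// inflate: [Differentiable<T>; PARAM_NUM] -> [Differentiable<T>; INFLATED_NUM]
// (predictor.inflate)(theta)
//
// After: a scalar-to-scalar function is mapped over every leaf, so the
// parameter count stays PARAM_NUM and INFLATED_NUM disappears entirely.
// inflate: Scalar<T> -> Scalar<T>
// theta.map(|x| x.map(&mut predictor.inflate))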
-fn to_not_nan_1<T, const N: usize>(xs: [T; N]) -> [NotNan<T>; N]
-where
-T: ordered_float::Float,
-{
-xs.map(|x| NotNan::new(x).expect("not nan"))
-}
-fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
-where
-T: ordered_float::Float,
-{
-xs.map(to_not_nan_1)
-}
fn collect_vec<T>(input: RankedDifferentiable<NotNan<T>, 1>) -> Vec<T>
where
T: Copy,
@@ -194,91 +171,9 @@ fn main() {
#[cfg(test)]
mod tests {
use super::*;
-use little_learner::{
-auto_diff::grad,
-loss::{
-l2_loss_2, line_unranked_predictor, predict_line_2, predict_line_2_unranked,
-quadratic_unranked_predictor,
-},
-};
+use little_learner::loss::{line_unranked_predictor, quadratic_unranked_predictor};
use rand::SeedableRng;
use crate::with_tensor::{l2_loss, predict_line};
-#[test]
-fn loss_example() {
-let xs = [2.0, 1.0, 4.0, 3.0];
-let ys = [1.8, 1.2, 4.2, 3.3];
-let loss = l2_loss_2(
-predict_line_2,
-RankedDifferentiable::of_slice(&xs),
-RankedDifferentiable::of_slice(&ys),
-&[
-RankedDifferentiable::of_scalar(Scalar::zero()),
-RankedDifferentiable::of_scalar(Scalar::zero()),
-],
-);
-assert_eq!(*loss.real_part(), 33.21);
-}
-#[test]
-fn l2_loss_non_autodiff_example() {
-let xs = [2.0, 1.0, 4.0, 3.0];
-let ys = [1.8, 1.2, 4.2, 3.3];
-let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
-assert_eq!(loss, 32.5892403);
-}
-#[test]
-fn grad_example() {
-let input_vec = [Differentiable::of_scalar(Scalar::make(
-NotNan::new(27.0).expect("not nan"),
-))];
-let grad: Vec<_> = grad(
-|x| {
-RankedDifferentiable::of_scalar(
-x[0].borrow_scalar().clone() * x[0].borrow_scalar().clone(),
-)
-},
-&input_vec,
-)
-.into_iter()
-.map(|x| x.into_scalar().real_part().into_inner())
-.collect();
-assert_eq!(grad, [54.0]);
-}
-#[test]
-fn loss_gradient() {
-let zero = Scalar::<NotNan<f64>>::zero();
-let input_vec = [
-RankedDifferentiable::of_scalar(zero.clone()).to_unranked(),
-RankedDifferentiable::of_scalar(zero).to_unranked(),
-];
-let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
-let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);
-let grad = grad(
-|x| {
-RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(l2_loss_2(
-predict_line_2_unranked,
-RankedDifferentiable::of_slice(&xs),
-RankedDifferentiable::of_slice(&ys),
-x,
-))])
-},
-&input_vec,
-);
-assert_eq!(
-grad.into_iter()
-.map(|x| *(x.into_scalar().real_part()))
-.collect::<Vec<_>>(),
-[-63.0, -21.0]
-);
-}
#[test]
fn test_iterate() {
let f = |t: [i32; 3]| t.map(|i| i - 3);

View File

@@ -125,4 +125,12 @@ mod tests {
);
assert_eq!((100.0 * loss).round() / 100.0, 32.59);
}
#[test]
fn l2_loss_non_autodiff_example() {
let xs = [2.0, 1.0, 4.0, 3.0];
let ys = [1.8, 1.2, 4.2, 3.3];
let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
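// sum_i (y_i - 0.0099 * x_i)^2
// = 3.16911204 + 1.41633801 + 17.30892816 + 10.69486209 = 32.5892403.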
assert_eq!(loss, 32.5892403);
}
}