Make Scalar numlike (#15)

Patrick Stevens authored on 2023-04-30 13:09:16 +01:00 (committed by GitHub)
parent ae6430aa85
commit 64d98757f4
8 changed files with 218 additions and 123 deletions
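The title refers to the library's `NumLike` trait (imported from `little_learner::traits` in the first hunk below): `Scalar<T>` now satisfies the same "behaves like a number" bundle of bounds that a raw `T` does, which is what lets `Predictor`'s `inflate`/`deflate` work one scalar at a time. The trait's definition is not part of this diff; the following is only a sketch of what such a bundle typically looks like, where everything except the names `NumLike` and `Zero` is an assumption:

```rust
use std::ops::{Add, Mul, Neg};

pub trait Zero {
    fn zero() -> Self;
}

// Assumed companion to `Zero`; the real crate may structure this differently.
pub trait One {
    fn one() -> Self;
}

// One bound capturing "behaves like a number", so signatures such as
// `gradient_descent` can say `T: NumLike` instead of listing every operation.
pub trait NumLike:
    Zero + One + Add<Output = Self> + Mul<Output = Self> + Neg<Output = Self> + Clone + Sized
{
}

// Blanket impl: anything with the listed operations is automatically NumLike.
impl<T> NumLike for T where
    T: Zero + One + Add<Output = T> + Mul<Output = T> + Neg<Output = T> + Clone
{
}

impl Zero for f64 {
    fn zero() -> f64 {
        0.0
    }
}

impl One for f64 {
    fn one() -> f64 {
        1.0
    }
}

fn square<T: NumLike>(x: T) -> T {
    x.clone() * x
}

fn main() {
    assert_eq!(square(3.0), 9.0);
}
```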


@@ -11,6 +11,7 @@ use little_learner::auto_diff::{grad, Differentiable, RankedDifferentiable};
 use crate::sample::sample2;
 use little_learner::loss::{l2_loss_2, plane_predictor, Predictor};
+use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
 use little_learner::scalar::Scalar;
 use little_learner::traits::{NumLike, Zero};
 use ordered_float::NotNan;
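The new `little_learner::not_nan` import replaces two helpers that until now were defined locally in this file (their removed definitions appear in a later hunk): each wraps every element of an array in `ordered_float::NotNan`, panicking on NaN input. A standalone usage sketch of the same shape:

```rust
use ordered_float::NotNan;

// Same shape as the relocated helper: wrap each element, rejecting NaNs.
fn to_not_nan_1<T: ordered_float::Float, const N: usize>(xs: [T; N]) -> [NotNan<T>; N] {
    xs.map(|x| NotNan::new(x).expect("not nan"))
}

fn main() {
    let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);
    // NotNan is Ord and Hash, which is why the crate stores floats this way.
    assert_eq!(ys[3].into_inner(), 3.3);
}
```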
@@ -47,7 +48,7 @@ where
         let delta = &delta[i];
         i += 1;
         // For speed, you might want to truncate_dual this.
-        let learning_rate = Scalar::make(learning_rate);
+        let learning_rate = Scalar::make(learning_rate.clone());
         Differentiable::map2(
             &theta,
             &delta.map(&mut |s| s * learning_rate.clone()),
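The only change in this hunk is the added `.clone()`: with the `NumLike` rework the hyperparameter type is apparently no longer required to be `Copy`, so the value must be cloned explicitly before `Scalar::make` takes ownership. A minimal illustration of the pattern (the function and its bounds here are assumptions, not the crate's API):

```rust
// Illustrative only: with a Clone (not Copy) bound, every use that takes
// ownership needs an explicit clone.
fn scale_all<T: Clone + std::ops::Mul<Output = T>>(delta: &[T], learning_rate: &T) -> Vec<T> {
    delta
        .iter()
        // Both operands are consumed by `*`, so both sides are cloned.
        .map(|s| s.clone() * learning_rate.clone())
        .collect()
}

fn main() {
    assert_eq!(scale_all(&[2.0, 4.0], &0.5), vec![1.0, 2.0]);
}
```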
@@ -56,37 +57,27 @@
         })
 }
 
-fn gradient_descent<
-    'a,
-    T,
-    R: Rng,
-    Point,
-    F,
-    G,
-    const IN_SIZE: usize,
-    const PARAM_NUM: usize,
-    const INFLATED_NUM: usize,
->(
+fn gradient_descent<'a, T, R: Rng, Point, F, G, const IN_SIZE: usize, const PARAM_NUM: usize>(
     mut hyper: GradientDescentHyper<T, R>,
     xs: &'a [Point],
     to_ranked_differentiable: G,
     ys: &[T],
     zero_params: [Differentiable<T>; PARAM_NUM],
-    predictor: Predictor<F, [Differentiable<T>; INFLATED_NUM], [Differentiable<T>; PARAM_NUM]>,
+    mut predictor: Predictor<F, Scalar<T>, Scalar<T>>,
 ) -> [Differentiable<T>; PARAM_NUM]
 where
-    T: NumLike + Eq + Hash,
+    T: NumLike + Hash + Copy + Default,
     Point: 'a + Copy,
     F: Fn(
         RankedDifferentiable<T, IN_SIZE>,
-        &[Differentiable<T>; INFLATED_NUM],
+        &[Differentiable<T>; PARAM_NUM],
     ) -> RankedDifferentiable<T, 1>,
     G: for<'b> Fn(&'b [Point]) -> RankedDifferentiable<T, IN_SIZE>,
 {
     let iterations = hyper.iterations;
     iterate(
         |theta| {
-            let out = gradient_descent_step::<T, _, 1, INFLATED_NUM>(
+            let out = gradient_descent_step::<T, _, 1, PARAM_NUM>(
                 &mut |x| match hyper.sampling.as_mut() {
                     None => RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(
                         l2_loss_2(
@@ -108,30 +99,16 @@ where
                         )])
                     }
                 },
-                (predictor.inflate)(theta),
+                theta.map(|x| x.map(&mut predictor.inflate)),
                 hyper.learning_rate,
             );
-            (predictor.deflate)(out)
+            out.map(|x| x.map(&mut predictor.deflate))
         },
         zero_params,
         iterations,
     )
 }
 
-fn to_not_nan_1<T, const N: usize>(xs: [T; N]) -> [NotNan<T>; N]
-where
-    T: ordered_float::Float,
-{
-    xs.map(|x| NotNan::new(x).expect("not nan"))
-}
-
-fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
-where
-    T: ordered_float::Float,
-{
-    xs.map(to_not_nan_1)
-}
-
 fn collect_vec<T>(input: RankedDifferentiable<NotNan<T>, 1>) -> Vec<T>
 where
     T: Copy,
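This hunk is the heart of the change. Previously `inflate` and `deflate` converted a whole parameter array between `[Differentiable<T>; PARAM_NUM]` and `[Differentiable<T>; INFLATED_NUM]`; now each is a per-scalar function, applied to every leaf by the nested `map`s (the outer `map` walks the array of parameter trees, `Differentiable`'s `map` visits each `Scalar` leaf). A sketch of the reshaped type, inferred from its usage here (the real definition lives in `little_learner::loss` and may differ):

```rust
// Inferred shape, not the crate's actual definition.
struct Predictor<F, Inflated, Deflated> {
    predict: F,
    inflate: fn(Deflated) -> Inflated, // lifts one parameter scalar
    deflate: fn(Inflated) -> Deflated, // projects one scalar back
}

fn main() {
    // Toy instance over plain f64 "scalars": inflate/deflate are inverses.
    let p = Predictor {
        predict: |x: f64, theta: &[f64; 2]| theta[0] * x + theta[1],
        inflate: |x: f64| x * 2.0,
        deflate: |x: f64| x / 2.0,
    };
    let theta = [3.0_f64, 1.0];
    // Applied pointwise, as the new gradient_descent does with nested maps.
    let inflated = theta.map(p.inflate);
    assert_eq!((p.deflate)(inflated[0]), 3.0);
    assert_eq!((p.predict)(2.0, &theta), 7.0);
}
```

With `Inflated = Deflated = Scalar<T>`, `gradient_descent` no longer needs to know the size of any inflated representation, which is why the `INFLATED_NUM` const parameter could be deleted above.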
@@ -194,91 +171,9 @@ fn main() {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use little_learner::{
-        auto_diff::grad,
-        loss::{
-            l2_loss_2, line_unranked_predictor, predict_line_2, predict_line_2_unranked,
-            quadratic_unranked_predictor,
-        },
-    };
+    use little_learner::loss::{line_unranked_predictor, quadratic_unranked_predictor};
     use rand::SeedableRng;
 
-    use crate::with_tensor::{l2_loss, predict_line};
-
-    #[test]
-    fn loss_example() {
-        let xs = [2.0, 1.0, 4.0, 3.0];
-        let ys = [1.8, 1.2, 4.2, 3.3];
-        let loss = l2_loss_2(
-            predict_line_2,
-            RankedDifferentiable::of_slice(&xs),
-            RankedDifferentiable::of_slice(&ys),
-            &[
-                RankedDifferentiable::of_scalar(Scalar::zero()),
-                RankedDifferentiable::of_scalar(Scalar::zero()),
-            ],
-        );
-        assert_eq!(*loss.real_part(), 33.21);
-    }
-
-    #[test]
-    fn l2_loss_non_autodiff_example() {
-        let xs = [2.0, 1.0, 4.0, 3.0];
-        let ys = [1.8, 1.2, 4.2, 3.3];
-        let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
-        assert_eq!(loss, 32.5892403);
-    }
-
-    #[test]
-    fn grad_example() {
-        let input_vec = [Differentiable::of_scalar(Scalar::make(
-            NotNan::new(27.0).expect("not nan"),
-        ))];
-
-        let grad: Vec<_> = grad(
-            |x| {
-                RankedDifferentiable::of_scalar(
-                    x[0].borrow_scalar().clone() * x[0].borrow_scalar().clone(),
-                )
-            },
-            &input_vec,
-        )
-        .into_iter()
-        .map(|x| x.into_scalar().real_part().into_inner())
-        .collect();
-
-        assert_eq!(grad, [54.0]);
-    }
-
-    #[test]
-    fn loss_gradient() {
-        let zero = Scalar::<NotNan<f64>>::zero();
-        let input_vec = [
-            RankedDifferentiable::of_scalar(zero.clone()).to_unranked(),
-            RankedDifferentiable::of_scalar(zero).to_unranked(),
-        ];
-        let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
-        let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);
-        let grad = grad(
-            |x| {
-                RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(l2_loss_2(
-                    predict_line_2_unranked,
-                    RankedDifferentiable::of_slice(&xs),
-                    RankedDifferentiable::of_slice(&ys),
-                    x,
-                ))])
-            },
-            &input_vec,
-        );
-        assert_eq!(
-            grad.into_iter()
-                .map(|x| *(x.into_scalar().real_part()))
-                .collect::<Vec<_>>(),
-            [-63.0, -21.0]
-        );
-    }
-
     #[test]
     fn test_iterate() {
         let f = |t: [i32; 3]| t.map(|i| i - 3);
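The expected constants in the tests removed from this file are easy to re-derive. With both line parameters zero, every prediction is 0, so the loss is just the sum of squared targets: 1.8² + 1.2² + 4.2² + 3.3² = 33.21. `grad_example` differentiates x·x at x = 27, giving 2 · 27 = 54. For `loss_gradient`, ∂/∂w Σ(yᵢ − (w·xᵢ + b))² at (0, 0) is −2 Σ xᵢyᵢ = −63, and ∂/∂b is −2 Σ yᵢ = −21. A plain-f64 check of those numbers, independent of the autodiff machinery:

```rust
fn main() {
    let xs = [2.0_f64, 1.0, 4.0, 3.0];
    let ys = [1.8_f64, 1.2, 4.2, 3.3];

    // Loss at (w, b) = (0, 0): every prediction is zero.
    let loss: f64 = ys.iter().map(|y| y * y).sum();

    // Analytic gradient of sum((y - (w*x + b))^2) at (0, 0).
    let dw: f64 = xs.iter().zip(&ys).map(|(x, y)| -2.0 * x * y).sum();
    let db: f64 = ys.iter().map(|y| -2.0 * y).sum();

    // Tolerances rather than bit-exact equality: this does not reproduce
    // the crate's exact operation order.
    assert!((loss - 33.21).abs() < 1e-9);
    assert!((dw + 63.0).abs() < 1e-9);
    assert!((db + 21.0).abs() < 1e-9);

    // d(x*x)/dx at 27.
    assert_eq!(2.0 * 27.0_f64, 54.0);
}
```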


@@ -125,4 +125,12 @@ mod tests {
         );
         assert_eq!((100.0 * loss).round() / 100.0, 32.59);
     }
+
+    #[test]
+    fn l2_loss_non_autodiff_example() {
+        let xs = [2.0, 1.0, 4.0, 3.0];
+        let ys = [1.8, 1.2, 4.2, 3.3];
+        let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
+        assert_eq!(loss, 32.5892403);
+    }
 }
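The relocated test's expected constant can also be reproduced by hand: with θ = (0.0099, 0) the four predictions are 0.0198, 0.0099, 0.0396, 0.0297, and the squared residuals sum to 1.7802² + 1.1901² + 4.1604² + 3.2703² = 32.5892403 in exact arithmetic. A plain-f64 sketch, assuming (as in the book this repository follows) that `predict_line` computes w·x + b:

```rust
// Re-derivation of the expected constant with no autodiff involved.
fn main() {
    let xs = [2.0_f64, 1.0, 4.0, 3.0];
    let ys = [1.8_f64, 1.2, 4.2, 3.3];
    let (w, b) = (0.0099_f64, 0.0_f64);
    let loss: f64 = xs
        .iter()
        .zip(&ys)
        .map(|(x, y)| {
            let residual = y - (w * x + b);
            residual * residual
        })
        .sum();
    // Tolerance rather than equality: bit-exact results depend on the
    // crate's operation order.
    assert!((loss - 32.5892403).abs() < 1e-9);
}
```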