Make Scalar numlike (#15)
This commit is contained in:
@@ -11,6 +11,7 @@ use little_learner::auto_diff::{grad, Differentiable, RankedDifferentiable};
|
||||
|
||||
use crate::sample::sample2;
|
||||
use little_learner::loss::{l2_loss_2, plane_predictor, Predictor};
|
||||
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::{NumLike, Zero};
|
||||
use ordered_float::NotNan;
|
||||
@@ -47,7 +48,7 @@ where
|
||||
let delta = &delta[i];
|
||||
i += 1;
|
||||
// For speed, you might want to truncate_dual this.
|
||||
let learning_rate = Scalar::make(learning_rate);
|
||||
let learning_rate = Scalar::make(learning_rate.clone());
|
||||
Differentiable::map2(
|
||||
&theta,
|
||||
&delta.map(&mut |s| s * learning_rate.clone()),
|
||||
@@ -56,37 +57,27 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
fn gradient_descent<
|
||||
'a,
|
||||
T,
|
||||
R: Rng,
|
||||
Point,
|
||||
F,
|
||||
G,
|
||||
const IN_SIZE: usize,
|
||||
const PARAM_NUM: usize,
|
||||
const INFLATED_NUM: usize,
|
||||
>(
|
||||
fn gradient_descent<'a, T, R: Rng, Point, F, G, const IN_SIZE: usize, const PARAM_NUM: usize>(
|
||||
mut hyper: GradientDescentHyper<T, R>,
|
||||
xs: &'a [Point],
|
||||
to_ranked_differentiable: G,
|
||||
ys: &[T],
|
||||
zero_params: [Differentiable<T>; PARAM_NUM],
|
||||
predictor: Predictor<F, [Differentiable<T>; INFLATED_NUM], [Differentiable<T>; PARAM_NUM]>,
|
||||
mut predictor: Predictor<F, Scalar<T>, Scalar<T>>,
|
||||
) -> [Differentiable<T>; PARAM_NUM]
|
||||
where
|
||||
T: NumLike + Eq + Hash,
|
||||
T: NumLike + Hash + Copy + Default,
|
||||
Point: 'a + Copy,
|
||||
F: Fn(
|
||||
RankedDifferentiable<T, IN_SIZE>,
|
||||
&[Differentiable<T>; INFLATED_NUM],
|
||||
&[Differentiable<T>; PARAM_NUM],
|
||||
) -> RankedDifferentiable<T, 1>,
|
||||
G: for<'b> Fn(&'b [Point]) -> RankedDifferentiable<T, IN_SIZE>,
|
||||
{
|
||||
let iterations = hyper.iterations;
|
||||
iterate(
|
||||
|theta| {
|
||||
let out = gradient_descent_step::<T, _, 1, INFLATED_NUM>(
|
||||
let out = gradient_descent_step::<T, _, 1, PARAM_NUM>(
|
||||
&mut |x| match hyper.sampling.as_mut() {
|
||||
None => RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(
|
||||
l2_loss_2(
|
||||
@@ -108,30 +99,16 @@ where
|
||||
)])
|
||||
}
|
||||
},
|
||||
(predictor.inflate)(theta),
|
||||
theta.map(|x| x.map(&mut predictor.inflate)),
|
||||
hyper.learning_rate,
|
||||
);
|
||||
(predictor.deflate)(out)
|
||||
out.map(|x| x.map(&mut predictor.deflate))
|
||||
},
|
||||
zero_params,
|
||||
iterations,
|
||||
)
|
||||
}
|
||||
|
||||
fn to_not_nan_1<T, const N: usize>(xs: [T; N]) -> [NotNan<T>; N]
|
||||
where
|
||||
T: ordered_float::Float,
|
||||
{
|
||||
xs.map(|x| NotNan::new(x).expect("not nan"))
|
||||
}
|
||||
|
||||
fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
|
||||
where
|
||||
T: ordered_float::Float,
|
||||
{
|
||||
xs.map(to_not_nan_1)
|
||||
}
|
||||
|
||||
fn collect_vec<T>(input: RankedDifferentiable<NotNan<T>, 1>) -> Vec<T>
|
||||
where
|
||||
T: Copy,
|
||||
@@ -194,91 +171,9 @@ fn main() {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use little_learner::{
|
||||
auto_diff::grad,
|
||||
loss::{
|
||||
l2_loss_2, line_unranked_predictor, predict_line_2, predict_line_2_unranked,
|
||||
quadratic_unranked_predictor,
|
||||
},
|
||||
};
|
||||
use little_learner::loss::{line_unranked_predictor, quadratic_unranked_predictor};
|
||||
use rand::SeedableRng;
|
||||
|
||||
use crate::with_tensor::{l2_loss, predict_line};
|
||||
|
||||
#[test]
|
||||
fn loss_example() {
|
||||
let xs = [2.0, 1.0, 4.0, 3.0];
|
||||
let ys = [1.8, 1.2, 4.2, 3.3];
|
||||
let loss = l2_loss_2(
|
||||
predict_line_2,
|
||||
RankedDifferentiable::of_slice(&xs),
|
||||
RankedDifferentiable::of_slice(&ys),
|
||||
&[
|
||||
RankedDifferentiable::of_scalar(Scalar::zero()),
|
||||
RankedDifferentiable::of_scalar(Scalar::zero()),
|
||||
],
|
||||
);
|
||||
|
||||
assert_eq!(*loss.real_part(), 33.21);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn l2_loss_non_autodiff_example() {
|
||||
let xs = [2.0, 1.0, 4.0, 3.0];
|
||||
let ys = [1.8, 1.2, 4.2, 3.3];
|
||||
let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
|
||||
assert_eq!(loss, 32.5892403);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grad_example() {
|
||||
let input_vec = [Differentiable::of_scalar(Scalar::make(
|
||||
NotNan::new(27.0).expect("not nan"),
|
||||
))];
|
||||
|
||||
let grad: Vec<_> = grad(
|
||||
|x| {
|
||||
RankedDifferentiable::of_scalar(
|
||||
x[0].borrow_scalar().clone() * x[0].borrow_scalar().clone(),
|
||||
)
|
||||
},
|
||||
&input_vec,
|
||||
)
|
||||
.into_iter()
|
||||
.map(|x| x.into_scalar().real_part().into_inner())
|
||||
.collect();
|
||||
assert_eq!(grad, [54.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn loss_gradient() {
|
||||
let zero = Scalar::<NotNan<f64>>::zero();
|
||||
let input_vec = [
|
||||
RankedDifferentiable::of_scalar(zero.clone()).to_unranked(),
|
||||
RankedDifferentiable::of_scalar(zero).to_unranked(),
|
||||
];
|
||||
let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
|
||||
let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);
|
||||
let grad = grad(
|
||||
|x| {
|
||||
RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(l2_loss_2(
|
||||
predict_line_2_unranked,
|
||||
RankedDifferentiable::of_slice(&xs),
|
||||
RankedDifferentiable::of_slice(&ys),
|
||||
x,
|
||||
))])
|
||||
},
|
||||
&input_vec,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
grad.into_iter()
|
||||
.map(|x| *(x.into_scalar().real_part()))
|
||||
.collect::<Vec<_>>(),
|
||||
[-63.0, -21.0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_iterate() {
|
||||
let f = |t: [i32; 3]| t.map(|i| i - 3);
|
||||
|
@@ -125,4 +125,12 @@ mod tests {
|
||||
);
|
||||
assert_eq!((100.0 * loss).round() / 100.0, 32.59);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn l2_loss_non_autodiff_example() {
|
||||
let xs = [2.0, 1.0, 4.0, 3.0];
|
||||
let ys = [1.8, 1.2, 4.2, 3.3];
|
||||
let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
|
||||
assert_eq!(loss, 32.5892403);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user