Implement Adam, and fix RMS (#22)
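Roughly, the new `predictor::adam` tags every parameter with two accumulators — a smoothed gradient (the velocity, decayed by `mu`) and a smoothed squared gradient (decayed by `beta`) — and steps by the velocity scaled by an adaptive learning rate. The RMS fix makes the existing `rms` predictor do the same scaling: it now steps by `learning_rate / (sqrt(r) + stabilizer)` and carries the smoothed `r` forward, where it previously stepped by the raw `learning_rate`. A minimal scalar sketch of the Adam step (not from this repository; the name `adam_step` is hypothetical, and it assumes `smooth(beta, old, new) = beta * old + (1 - beta) * new`):

// One Adam step on a single scalar parameter, mirroring the tagged update in the diff below.
// Returns the new parameter value plus the updated (velocity, r) accumulators.
fn adam_step(
    theta: f64, // current parameter
    g: f64,     // gradient of the loss at theta
    v: f64,     // smoothed gradient (velocity) carried between steps
    r: f64,     // smoothed squared gradient carried between steps
    mu: f64,    // velocity decay
    beta: f64,  // squared-gradient decay
    alpha: f64, // base learning rate
    eps: f64,   // stabilizer
) -> (f64, f64, f64) {
    let r = beta * r + (1.0 - beta) * g * g; // smooth the squared gradient
    let v = mu * v + (1.0 - mu) * g;         // smooth the gradient itself
    let lr = alpha / (r.sqrt() + eps);       // adaptive learning rate
    (theta - v * lr, v, r)                   // no bias correction, as in the book's version
}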
.gitignore (vendored)
@@ -3,3 +3,4 @@ target/
*.iml
.vscode/
.profile*
.DS_Store

@@ -2,7 +2,7 @@ use crate::auto_diff::{grad, Differentiable, RankedDifferentiable};
use crate::hyper;
use crate::loss::l2_loss_2;
use crate::predictor::Predictor;
use crate::sample::sample2;
use crate::sample;
use crate::traits::NumLike;
use rand::Rng;
use std::hash::Hash;
@@ -105,7 +105,7 @@ where
),
)]),
Some((rng, batch_size)) => {
let (sampled_xs, sampled_ys) = sample2(rng, *batch_size, xs, ys);
let (sampled_xs, sampled_ys) = sample::take_2(rng, *batch_size, xs, ys);
RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(
l2_loss_2(
&predictor.predict,
@@ -369,7 +369,7 @@ mod tests {
fn test_with_rms() {
let beta = NotNan::new(0.9).expect("not nan");
let stabilizer = NotNan::new(0.00000001).expect("not nan");
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.001).expect("not nan"), 3000)
let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
.with_stabilizer(stabilizer)
.with_beta(beta);

@@ -404,7 +404,56 @@ mod tests {
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(fitted_theta0, [3.985_350_099_342_649, 1.9745945728216352]);
assert_eq!(fitted_theta1, 6.164_222_983_181_168);
assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]);
assert_eq!(fitted_theta1, 6.1645790482740361);
}

#[test]
fn test_with_adam() {
let beta = NotNan::new(0.9).expect("not nan");
let stabilizer = NotNan::new(0.00000001).expect("not nan");
let mu = NotNan::new(0.85).expect("not nan");
// Erratum in the book: they printed 0.001 but intended 0.01.
let hyper = hyper::AdamGradientDescent::default(NotNan::new(0.01).expect("not nan"), 1500)
.with_stabilizer(stabilizer)
.with_beta(beta)
.with_mu(mu);

let iterated = {
let xs = to_not_nan_2(PLANE_XS);
let ys = to_not_nan_1(PLANE_YS);
let zero_params = [
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
.to_unranked(),
Differentiable::of_scalar(Scalar::zero()),
];

gradient_descent(
hyper,
&xs,
RankedDifferentiableTagged::of_slice_2::<_, 2>,
&ys,
zero_params,
predictor::adam(predict_plane),
hyper::AdamGradientDescent::to_immutable,
)
};

let [theta0, theta1] = iterated;

let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");

let fitted_theta0 = theta0
.collect()
.iter()
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(
fitted_theta0,
[3.980_262_420_345_729_5, 1.977_071_898_301_443_9]
);
assert_eq!(fitted_theta1, 6.170_196_024_282_712_5);
}
}

@@ -1,4 +1,4 @@
use crate::predictor::{NakedHypers, RmsHyper, VelocityHypers};
use crate::predictor::{AdamHyper, NakedHypers, RmsHyper, VelocityHypers};
use crate::traits::{NumLike, Zero};
use rand::rngs::StdRng;

@@ -124,6 +124,23 @@ impl<A, Rng> From<VelocityGradientDescent<A, Rng>> for BaseGradientDescent<Rng>
}
}

fn ten<A>() -> A
where
A: NumLike,
{
let two = A::one() + A::one();
two.clone() * two.clone() * two.clone() + two
}

fn one_ten_k<A>() -> A
where
A: NumLike,
{
let one_tenth = A::one() / ten();
let one_hundredth = one_tenth.clone() * one_tenth;
one_hundredth.clone() * one_hundredth
}

pub struct RmsGradientDescent<A, Rng> {
base: BaseGradientDescent<Rng>,
rms: RmsHyper<A>,
@@ -134,17 +151,11 @@ impl<A> RmsGradientDescent<A, StdRng> {
where
A: NumLike,
{
let two = A::one() + A::one();
let ten = two.clone() * two.clone() * two.clone() + two;
let one_tenth = A::one() / ten.clone();
let one_hundredth = one_tenth.clone() * one_tenth;
let one_ten_k = one_hundredth.clone() * one_hundredth;

RmsGradientDescent {
base: BaseGradientDescent::new(iterations),
rms: RmsHyper {
stabilizer: one_ten_k.clone() * one_ten_k,
beta: A::one() + -(A::one() / ten),
stabilizer: one_ten_k::<A>() * one_ten_k(),
beta: A::one() + -(A::one() / ten()),
learning_rate,
},
}
@@ -189,3 +200,66 @@ impl<A, Rng> From<RmsGradientDescent<A, Rng>> for BaseGradientDescent<Rng> {
val.base
}
}

pub struct AdamGradientDescent<A, Rng> {
base: BaseGradientDescent<Rng>,
adam: AdamHyper<A>,
}

impl<A> AdamGradientDescent<A, StdRng> {
pub fn default(learning_rate: A, iterations: u32) -> Self
where
A: NumLike,
{
AdamGradientDescent {
base: BaseGradientDescent::new(iterations),
adam: AdamHyper {
mu: A::zero(),
rms: RmsHyper {
learning_rate,
stabilizer: one_ten_k::<A>() * one_ten_k(),
beta: A::one() + -(A::one() / ten()),
},
},
}
}
}

impl<A, Rng> AdamGradientDescent<A, Rng> {
#[must_use]
pub fn with_stabilizer(self, stabilizer: A) -> Self {
AdamGradientDescent {
base: self.base,
adam: self.adam.with_stabilizer(stabilizer),
}
}

#[must_use]
pub fn with_beta(self, beta: A) -> Self {
AdamGradientDescent {
base: self.base,
adam: self.adam.with_beta(beta),
}
}

#[must_use]
pub fn with_mu(self, mu: A) -> Self {
AdamGradientDescent {
base: self.base,
adam: self.adam.with_mu(mu),
}
}

pub fn to_immutable(&self) -> AdamHyper<A>
where
A: Clone,
{
self.adam.clone()
}
}

impl<A, Rng> From<AdamGradientDescent<A, Rng>> for BaseGradientDescent<Rng> {
fn from(val: AdamGradientDescent<A, Rng>) -> BaseGradientDescent<Rng> {
val.base
}
}

@@ -1,7 +1,7 @@
use crate::auto_diff::{Differentiable, DifferentiableTagged};
use crate::scalar::Scalar;
use crate::smooth::smooth;
use crate::traits::{NumLike, Sqrt};
use crate::traits::NumLike;

/// A Predictor is a function (`predict`) we're optimising, an `inflate` which adds any metadata
/// that the prediction engine might require, a corresponding `deflate` which removes the metadata,
@@ -49,6 +49,26 @@ pub struct RmsHyper<A> {
pub learning_rate: A,
}

impl<A> RmsHyper<A> {
#[must_use]
pub fn with_stabilizer(self, s: A) -> RmsHyper<A> {
RmsHyper {
learning_rate: self.learning_rate,
beta: self.beta,
stabilizer: s,
}
}

#[must_use]
pub fn with_beta(self, s: A) -> RmsHyper<A> {
RmsHyper {
learning_rate: self.learning_rate,
beta: s,
stabilizer: self.stabilizer,
}
}
}

pub const fn rms<F, A>(
f: F,
) -> Predictor<F, DifferentiableTagged<A, A>, Differentiable<A>, RmsHyper<A>>
@@ -60,26 +80,22 @@ where
inflate: |x| x.map_tag(&mut |()| A::zero()),
deflate: |x| x.map_tag(&mut |_| ()),
update: |theta, delta, hyper| {
DifferentiableTagged::map2_tagged(
&theta,
delta,
&mut |theta, smoothed_grad, delta, ()| {
let r = smooth(
Scalar::make(hyper.beta.clone()),
&Differentiable::of_scalar(Scalar::make(smoothed_grad)),
&Differentiable::of_scalar(delta.clone() * delta.clone()),
)
.into_scalar();
let learning_rate = Scalar::make(hyper.learning_rate.clone())
/ (r.sqrt() + Scalar::make(hyper.stabilizer.clone()));
(
(theta.clone()
+ -(delta.clone() * Scalar::make(hyper.learning_rate.clone())))
.truncate_dual(None),
learning_rate.clone_real_part(),
)
},
)
DifferentiableTagged::map2_tagged(&theta, delta, &mut |theta, smoothed_r, delta, ()| {
let r = smooth(
Scalar::make(hyper.beta.clone()),
&Differentiable::of_scalar(Scalar::make(smoothed_r)),
&Differentiable::of_scalar(delta.clone() * delta.clone()),
)
.into_scalar();
let learning_rate = hyper.learning_rate.clone()
/ (r.clone_real_part().sqrt() + hyper.stabilizer.clone());
(
Scalar::make(
theta.clone_real_part() + -(delta.clone_real_part() * learning_rate),
),
r.clone_real_part(),
)
})
},
}
}
@@ -109,3 +125,73 @@ where
},
}
}

#[derive(Clone)]
pub struct AdamHyper<A> {
pub rms: RmsHyper<A>,
pub mu: A,
}

impl<A> AdamHyper<A> {
#[must_use]
pub fn with_stabilizer(self, s: A) -> AdamHyper<A> {
AdamHyper {
mu: self.mu,
rms: self.rms.with_stabilizer(s),
}
}

#[must_use]
pub fn with_beta(self, s: A) -> AdamHyper<A> {
AdamHyper {
mu: self.mu,
rms: self.rms.with_beta(s),
}
}

#[must_use]
pub fn with_mu(self, mu: A) -> AdamHyper<A> {
AdamHyper { mu, rms: self.rms }
}
}

type AdamInflated<A> = DifferentiableTagged<A, (A, A)>;

pub const fn adam<F, A>(f: F) -> Predictor<F, AdamInflated<A>, Differentiable<A>, AdamHyper<A>>
where
A: NumLike,
{
Predictor {
predict: f,
inflate: |x| x.map_tag(&mut |()| (A::zero(), A::zero())),
deflate: |x| x.map_tag(&mut |_| ()),
update: |theta, delta, hyper| {
DifferentiableTagged::map2_tagged(
&theta,
delta,
&mut |theta, (smoothed_velocity, smoothed_r), delta, ()| {
let r = smooth(
Scalar::make(hyper.rms.beta.clone()),
&Differentiable::of_scalar(Scalar::make(smoothed_r)),
&Differentiable::of_scalar(delta.clone() * delta.clone()),
)
.into_scalar();
let learning_rate = hyper.rms.learning_rate.clone()
/ (r.clone_real_part().sqrt() + hyper.rms.stabilizer.clone());
let velocity = smooth(
Scalar::make(hyper.mu.clone()),
&Differentiable::of_scalar(Scalar::make(smoothed_velocity)),
&Differentiable::of_scalar(delta.clone()),
)
.into_scalar();
(
Scalar::make(
theta.clone_real_part() + -(velocity.clone_real_part() * learning_rate),
),
(velocity.clone_real_part(), r.clone_real_part()),
)
},
)
},
}
}
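Aside: with the default `mu: A::zero()` set in `AdamGradientDescent::default`, the smoothed velocity above reduces to the raw gradient, so `adam` takes exactly the step of the fixed `rms` predictor; in terms of the hypothetical scalar sketch near the top, `adam_step(theta, g, 0.0, r, 0.0, beta, alpha, eps)` reproduces the RMS rule.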

@@ -1,7 +1,7 @@
use rand::Rng;

/// Grab `n` random samples from `from_x` and `from_y`, collecting them into a vector.
pub fn sample2<R: Rng, T, U, I, J>(rng: &mut R, n: usize, from_x: I, from_y: J) -> (Vec<T>, Vec<U>)
pub fn take_2<R: Rng, T, U, I, J>(rng: &mut R, n: usize, from_x: I, from_y: J) -> (Vec<T>, Vec<U>)
where
T: Copy,
U: Copy,

@@ -116,9 +116,9 @@ mod test_smooth {
assert_eq!(
output,
vec![
vec![0.820_000_000_000_000_1, 2.9, 2.2800000000000002],
vec![0.820_000_000_000_000_1, 2.9, 2.280_000_000_000_000_2],
vec![2.078, 4.43, 6.191_999_999_999_999],
vec![1.9802, 4.0169999999999995, 12.302799999999998]
vec![1.9802, 4.016_999_999_999_999_5, 12.302_799_999_999_998]
]
);
}

@@ -60,7 +60,10 @@ fn main() {
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(fitted_theta0, [3.985_350_099_342_649, 1.9745945728216352]);
assert_eq!(
fitted_theta0,
[3.985_350_099_342_649, 1.974_594_572_821_635_2]
);
assert_eq!(fitted_theta1, 6.164_222_983_181_168);
}