Implement Adam, and fix RMS (#22)

Patrick Stevens
2023-05-08 11:39:42 +01:00
committed by GitHub
parent deb0ec67ca
commit a0da79591a
7 changed files with 252 additions and 39 deletions

.gitignore
View File

@@ -3,3 +3,4 @@ target/
*.iml
.vscode/
.profile*
+.DS_Store

View File

@@ -2,7 +2,7 @@ use crate::auto_diff::{grad, Differentiable, RankedDifferentiable};
use crate::hyper;
use crate::loss::l2_loss_2;
use crate::predictor::Predictor;
-use crate::sample::sample2;
+use crate::sample;
use crate::traits::NumLike;
use rand::Rng;
use std::hash::Hash;
@@ -105,7 +105,7 @@ where
),
)]),
Some((rng, batch_size)) => {
-let (sampled_xs, sampled_ys) = sample2(rng, *batch_size, xs, ys);
+let (sampled_xs, sampled_ys) = sample::take_2(rng, *batch_size, xs, ys);
RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(
l2_loss_2(
&predictor.predict,
@@ -369,7 +369,7 @@ mod tests {
fn test_with_rms() {
let beta = NotNan::new(0.9).expect("not nan");
let stabilizer = NotNan::new(0.00000001).expect("not nan");
-let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.001).expect("not nan"), 3000)
+let hyper = hyper::RmsGradientDescent::default(NotNan::new(0.01).expect("not nan"), 3000)
.with_stabilizer(stabilizer)
.with_beta(beta);
@@ -404,7 +404,56 @@ mod tests {
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
-assert_eq!(fitted_theta0, [3.985_350_099_342_649, 1.9745945728216352]);
-assert_eq!(fitted_theta1, 6.164_222_983_181_168);
+assert_eq!(fitted_theta0, [3.9746454441720851, 1.9714549220774951]);
+assert_eq!(fitted_theta1, 6.1645790482740361);
}
#[test]
fn test_with_adam() {
let beta = NotNan::new(0.9).expect("not nan");
let stabilizer = NotNan::new(0.00000001).expect("not nan");
let mu = NotNan::new(0.85).expect("not nan");
// Erratum in the book: they printed 0.001 but intended 0.01.
let hyper = hyper::AdamGradientDescent::default(NotNan::new(0.01).expect("not nan"), 1500)
.with_stabilizer(stabilizer)
.with_beta(beta)
.with_mu(mu);
let iterated = {
let xs = to_not_nan_2(PLANE_XS);
let ys = to_not_nan_1(PLANE_YS);
let zero_params = [
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
.to_unranked(),
Differentiable::of_scalar(Scalar::zero()),
];
gradient_descent(
hyper,
&xs,
RankedDifferentiableTagged::of_slice_2::<_, 2>,
&ys,
zero_params,
predictor::adam(predict_plane),
hyper::AdamGradientDescent::to_immutable,
)
};
let [theta0, theta1] = iterated;
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
let fitted_theta0 = theta0
.collect()
.iter()
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
assert_eq!(
fitted_theta0,
[3.980_262_420_345_729_5, 1.977_071_898_301_443_9]
);
assert_eq!(fitted_theta1, 6.170_196_024_282_712_5);
}
}

View File

@@ -1,4 +1,4 @@
-use crate::predictor::{NakedHypers, RmsHyper, VelocityHypers};
+use crate::predictor::{AdamHyper, NakedHypers, RmsHyper, VelocityHypers};
use crate::traits::{NumLike, Zero};
use rand::rngs::StdRng;
@@ -124,6 +124,23 @@ impl<A, Rng> From<VelocityGradientDescent<A, Rng>> for BaseGradientDescent<Rng>
}
}
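/// Builds the value 10 out of `A::one()` alone: (2 * 2 * 2) + 2.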
fn ten<A>() -> A
where
A: NumLike,
{
let two = A::one() + A::one();
two.clone() * two.clone() * two.clone() + two
}
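/// Builds 1/10_000 ((1/10) squared, then squared again); the default
/// stabilizer below is `one_ten_k() * one_ten_k()` = 1e-8.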
fn one_ten_k<A>() -> A
where
A: NumLike,
{
let one_tenth = A::one() / ten();
let one_hundredth = one_tenth.clone() * one_tenth;
one_hundredth.clone() * one_hundredth
}
pub struct RmsGradientDescent<A, Rng> {
base: BaseGradientDescent<Rng>,
rms: RmsHyper<A>,
@@ -134,17 +151,11 @@ impl<A> RmsGradientDescent<A, StdRng> {
where
A: NumLike,
{
-let two = A::one() + A::one();
-let ten = two.clone() * two.clone() * two.clone() + two;
-let one_tenth = A::one() / ten.clone();
-let one_hundredth = one_tenth.clone() * one_tenth;
-let one_ten_k = one_hundredth.clone() * one_hundredth;
RmsGradientDescent {
base: BaseGradientDescent::new(iterations),
rms: RmsHyper {
-stabilizer: one_ten_k.clone() * one_ten_k,
-beta: A::one() + -(A::one() / ten),
+stabilizer: one_ten_k::<A>() * one_ten_k(),
+beta: A::one() + -(A::one() / ten()),
learning_rate,
},
}
@@ -189,3 +200,66 @@ impl<A, Rng> From<RmsGradientDescent<A, Rng>> for BaseGradientDescent<Rng> {
val.base
}
}
pub struct AdamGradientDescent<A, Rng> {
base: BaseGradientDescent<Rng>,
adam: AdamHyper<A>,
}
impl<A> AdamGradientDescent<A, StdRng> {
pub fn default(learning_rate: A, iterations: u32) -> Self
where
A: NumLike,
{
AdamGradientDescent {
base: BaseGradientDescent::new(iterations),
adam: AdamHyper {
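// Defaults match RmsGradientDescent (stabilizer 1e-8, beta 0.9), plus mu = 0
// until overridden via `with_mu`.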
mu: A::zero(),
rms: RmsHyper {
learning_rate,
stabilizer: one_ten_k::<A>() * one_ten_k(),
beta: A::one() + -(A::one() / ten()),
},
},
}
}
}
impl<A, Rng> AdamGradientDescent<A, Rng> {
#[must_use]
pub fn with_stabilizer(self, stabilizer: A) -> Self {
AdamGradientDescent {
base: self.base,
adam: self.adam.with_stabilizer(stabilizer),
}
}
#[must_use]
pub fn with_beta(self, beta: A) -> Self {
AdamGradientDescent {
base: self.base,
adam: self.adam.with_beta(beta),
}
}
#[must_use]
pub fn with_mu(self, mu: A) -> Self {
AdamGradientDescent {
base: self.base,
adam: self.adam.with_mu(mu),
}
}
pub fn to_immutable(&self) -> AdamHyper<A>
where
A: Clone,
{
self.adam.clone()
}
}
impl<A, Rng> From<AdamGradientDescent<A, Rng>> for BaseGradientDescent<Rng> {
fn from(val: AdamGradientDescent<A, Rng>) -> BaseGradientDescent<Rng> {
val.base
}
}

View File

@@ -1,7 +1,7 @@
use crate::auto_diff::{Differentiable, DifferentiableTagged};
use crate::scalar::Scalar;
use crate::smooth::smooth;
-use crate::traits::{NumLike, Sqrt};
+use crate::traits::NumLike;
/// A Predictor is a function (`predict`) we're optimising, an `inflate` which adds any metadata
/// that the prediction engine might require, a corresponding `deflate` which removes the metadata,
@@ -49,6 +49,26 @@ pub struct RmsHyper<A> {
pub learning_rate: A,
}
impl<A> RmsHyper<A> {
#[must_use]
pub fn with_stabilizer(self, s: A) -> RmsHyper<A> {
RmsHyper {
learning_rate: self.learning_rate,
beta: self.beta,
stabilizer: s,
}
}
#[must_use]
pub fn with_beta(self, s: A) -> RmsHyper<A> {
RmsHyper {
learning_rate: self.learning_rate,
beta: s,
stabilizer: self.stabilizer,
}
}
}
pub const fn rms<F, A>(
f: F,
) -> Predictor<F, DifferentiableTagged<A, A>, Differentiable<A>, RmsHyper<A>>
@@ -60,26 +80,22 @@ where
inflate: |x| x.map_tag(&mut |()| A::zero()),
deflate: |x| x.map_tag(&mut |_| ()),
update: |theta, delta, hyper| {
-DifferentiableTagged::map2_tagged(
-&theta,
-delta,
-&mut |theta, smoothed_grad, delta, ()| {
-let r = smooth(
-Scalar::make(hyper.beta.clone()),
-&Differentiable::of_scalar(Scalar::make(smoothed_grad)),
-&Differentiable::of_scalar(delta.clone() * delta.clone()),
-)
-.into_scalar();
-let learning_rate = Scalar::make(hyper.learning_rate.clone())
-/ (r.sqrt() + Scalar::make(hyper.stabilizer.clone()));
-(
-(theta.clone()
-+ -(delta.clone() * Scalar::make(hyper.learning_rate.clone())))
-.truncate_dual(None),
-learning_rate.clone_real_part(),
-)
-},
-)
+DifferentiableTagged::map2_tagged(&theta, delta, &mut |theta, smoothed_r, delta, ()| {
+let r = smooth(
+Scalar::make(hyper.beta.clone()),
+&Differentiable::of_scalar(Scalar::make(smoothed_r)),
+&Differentiable::of_scalar(delta.clone() * delta.clone()),
+)
+.into_scalar();
+let learning_rate = hyper.learning_rate.clone()
+/ (r.clone_real_part().sqrt() + hyper.stabilizer.clone());
+(
+Scalar::make(
+theta.clone_real_part() + -(delta.clone_real_part() * learning_rate),
+),
+r.clone_real_part(),
+)
+})
},
}
}
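The fix is easier to see with the Scalar and tag plumbing stripped away. Below is a minimal per-parameter sketch in plain f64, assuming `smooth(decay, old, new) = decay * old + (1 - decay) * new`; the name `rms_step` is illustrative, not part of the crate:

fn rms_step(theta: f64, grad: f64, r: f64, beta: f64, stabilizer: f64, lr: f64) -> (f64, f64) {
    // Exponentially smoothed squared gradient, carried between steps in the tag.
    let r = beta * r + (1.0 - beta) * grad * grad;
    // Per-parameter learning rate. The old closure computed this but then
    // multiplied `grad` by the raw `lr` anyway, and stored the rate in the tag
    // instead of `r`; the new closure applies it and carries `r` forward.
    let rate = lr / (r.sqrt() + stabilizer);
    (theta - grad * rate, r)
}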
@@ -109,3 +125,73 @@ where
},
}
}
#[derive(Clone)]
pub struct AdamHyper<A> {
pub rms: RmsHyper<A>,
pub mu: A,
}
impl<A> AdamHyper<A> {
#[must_use]
pub fn with_stabilizer(self, s: A) -> AdamHyper<A> {
AdamHyper {
mu: self.mu,
rms: self.rms.with_stabilizer(s),
}
}
#[must_use]
pub fn with_beta(self, s: A) -> AdamHyper<A> {
AdamHyper {
mu: self.mu,
rms: self.rms.with_beta(s),
}
}
#[must_use]
pub fn with_mu(self, mu: A) -> AdamHyper<A> {
AdamHyper { mu, rms: self.rms }
}
}
type AdamInflated<A> = DifferentiableTagged<A, (A, A)>;
pub const fn adam<F, A>(f: F) -> Predictor<F, AdamInflated<A>, Differentiable<A>, AdamHyper<A>>
where
A: NumLike,
{
Predictor {
predict: f,
inflate: |x| x.map_tag(&mut |()| (A::zero(), A::zero())),
deflate: |x| x.map_tag(&mut |_| ()),
update: |theta, delta, hyper| {
DifferentiableTagged::map2_tagged(
&theta,
delta,
&mut |theta, (smoothed_velocity, smoothed_r), delta, ()| {
let r = smooth(
Scalar::make(hyper.rms.beta.clone()),
&Differentiable::of_scalar(Scalar::make(smoothed_r)),
&Differentiable::of_scalar(delta.clone() * delta.clone()),
)
.into_scalar();
let learning_rate = hyper.rms.learning_rate.clone()
/ (r.clone_real_part().sqrt() + hyper.rms.stabilizer.clone());
let velocity = smooth(
Scalar::make(hyper.mu.clone()),
&Differentiable::of_scalar(Scalar::make(smoothed_velocity)),
&Differentiable::of_scalar(delta.clone()),
)
.into_scalar();
(
Scalar::make(
theta.clone_real_part() + -(velocity.clone_real_part() * learning_rate),
),
(velocity.clone_real_part(), r.clone_real_part()),
)
},
)
},
}
}
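The Adam update above reduces to the same kind of per-parameter rule, now with a smoothed gradient (`velocity`) alongside the smoothed squared gradient. As with the RMS sketch, this is an illustrative plain-f64 reduction under the same `smooth` assumption, and, matching the code, it performs no bias correction:

fn adam_step(
    theta: f64,
    grad: f64,
    velocity: f64, // smoothed gradient, carried in the tag
    r: f64,        // smoothed squared gradient, carried in the tag
    mu: f64,
    beta: f64,
    stabilizer: f64,
    lr: f64,
) -> (f64, (f64, f64)) {
    let r = beta * r + (1.0 - beta) * grad * grad;
    let velocity = mu * velocity + (1.0 - mu) * grad;
    let rate = lr / (r.sqrt() + stabilizer);
    (theta - velocity * rate, (velocity, r))
}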

View File

@@ -1,7 +1,7 @@
use rand::Rng;
/// Grab `n` random samples from `from_x` and `from_y`, collecting them into a vector.
-pub fn sample2<R: Rng, T, U, I, J>(rng: &mut R, n: usize, from_x: I, from_y: J) -> (Vec<T>, Vec<U>)
+pub fn take_2<R: Rng, T, U, I, J>(rng: &mut R, n: usize, from_x: I, from_y: J) -> (Vec<T>, Vec<U>)
where
T: Copy,
U: Copy,

View File

@@ -116,9 +116,9 @@ mod test_smooth {
assert_eq!(
output,
vec![
-vec![0.820_000_000_000_000_1, 2.9, 2.2800000000000002],
+vec![0.820_000_000_000_000_1, 2.9, 2.280_000_000_000_000_2],
vec![2.078, 4.43, 6.191_999_999_999_999],
-vec![1.9802, 4.0169999999999995, 12.302799999999998]
+vec![1.9802, 4.016_999_999_999_999_5, 12.302_799_999_999_998]
]
);
}

View File

@@ -60,7 +60,10 @@ fn main() {
.map(|x| x.into_inner())
.collect::<Vec<_>>();
let fitted_theta1 = theta1.to_scalar().real_part().into_inner();
-assert_eq!(fitted_theta0, [3.985_350_099_342_649, 1.9745945728216352]);
+assert_eq!(
+fitted_theta0,
+[3.985_350_099_342_649, 1.974_594_572_821_635_2]
+);
assert_eq!(fitted_theta1, 6.164_222_983_181_168);
}