Deduplicate scalars (#7)
@@ -12,7 +12,7 @@ where
     A: Zero,
 {
     fn zero() -> DifferentiableHidden<A> {
-        DifferentiableHidden::Scalar(Scalar::Number(A::zero()))
+        DifferentiableHidden::Scalar(Scalar::Number(A::zero(), None))
     }
 }

@@ -21,7 +21,7 @@ where
     A: One,
 {
     fn one() -> Scalar<A> {
-        Scalar::Number(A::one())
+        Scalar::Number(A::one(), None)
     }
 }

@@ -46,6 +46,7 @@ where
     }
 }

+#[derive(Debug)]
 enum DifferentiableHidden<A> {
     Scalar(Scalar<A>),
     Vector(Vec<DifferentiableHidden<A>>),
@@ -71,9 +72,9 @@ where
 }

 impl<A> DifferentiableHidden<A> {
-    fn map<B, F>(&self, f: &F) -> DifferentiableHidden<B>
+    fn map<B, F>(&self, f: &mut F) -> DifferentiableHidden<B>
     where
-        F: Fn(Scalar<A>) -> Scalar<B>,
+        F: FnMut(Scalar<A>) -> Scalar<B>,
         A: Clone,
     {
         match self {
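Note on the `Fn` to `FnMut` change above: `grad` (later in this file) now walks the tree with a closure that increments a counter, so the callback must be allowed to mutate captured state. A minimal standalone sketch of the distinction, not using the crate's types:

```rust
// Standalone sketch: a callback that tags each visited element with a running
// counter must mutate captured state, which `Fn` forbids but `FnMut` allows.
fn tag_all<F>(items: &[f64], f: &mut F) -> Vec<(f64, usize)>
where
    F: FnMut(f64) -> (f64, usize),
{
    items.iter().map(|&x| f(x)).collect()
}

fn main() {
    let mut i = 0usize;
    let tagged = tag_all(&[0.0, 0.0], &mut |x| {
        let result = (x, i); // pair each value with a fresh index
        i += 1;
        result
    });
    assert_eq!(tagged, vec![(0.0, 0), (0.0, 1)]);
}
```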
@@ -114,7 +115,7 @@ impl<A> DifferentiableHidden<A> {
         DifferentiableHidden::Vector(
             input
                 .iter()
-                .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone())))
+                .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone(), None)))
                 .collect(),
         )
     }
@@ -131,7 +132,8 @@ where
         + Div<Output = A>
         + Zero
         + One
-        + Neg<Output = A>,
+        + Neg<Output = A>
+        + Display,
 {
     fn accumulate_gradients_vec(v: &[DifferentiableHidden<A>], acc: &mut HashMap<Scalar<A>, A>) {
         for v in v.iter().rev() {
@@ -155,14 +157,14 @@ where
         let mut acc = HashMap::new();
         self.accumulate_gradients(&mut acc);

-        wrt.map(&|d| match acc.get(&d) {
-            None => Scalar::Number(A::zero()),
-            Some(x) => Scalar::Number(x.clone()),
+        wrt.map(&mut |d| match acc.get(&d) {
+            None => Scalar::Number(A::zero(), None),
+            Some(x) => Scalar::Number(x.clone(), None),
         })
     }
 }

-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub struct Differentiable<A, const RANK: usize> {
     contents: DifferentiableHidden<A>,
 }
@@ -205,9 +207,9 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
         }
     }

-    pub fn map<B, F>(s: Differentiable<A, RANK>, f: &F) -> Differentiable<B, RANK>
+    pub fn map<B, F>(s: Differentiable<A, RANK>, f: &mut F) -> Differentiable<B, RANK>
     where
-        F: Fn(Scalar<A>) -> Scalar<B>,
+        F: FnMut(Scalar<A>) -> Scalar<B>,
         A: Clone,
     {
         Differentiable {
@@ -252,9 +254,15 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
         + Zero
         + One
         + Neg<Output = A>
-        + Eq,
+        + Eq
+        + std::fmt::Display,
     {
-        let wrt = theta.contents.map(&Scalar::truncate_dual);
+        let mut i = 0usize;
+        let wrt = theta.contents.map(&mut |x| {
+            let result = Scalar::truncate_dual(x, i);
+            i += 1;
+            result
+        });
         let after_f = f(Differentiable {
             contents: wrt.clone(),
         });
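The counter introduced here gives each parameter leaf its own index. Gradients are accumulated in a `HashMap` keyed by the scalar itself, so two parameters that start with equal values (e.g. theta = (0, 0)) would previously collapse into a single key and share one gradient slot; tagging each leaf with its position keeps them apart. A standalone sketch of the collision, with made-up gradient numbers and plain integer stand-ins rather than the crate's types:

```rust
use std::collections::HashMap;

fn main() {
    let params = [0, 0]; // two distinct parameters with the same starting value
    let grads = [-63, -21]; // illustrative per-parameter gradients

    // Keyed by value alone: both parameters land in one slot and get merged.
    let mut by_value: HashMap<i32, i32> = HashMap::new();
    for (p, g) in params.iter().zip(grads) {
        *by_value.entry(*p).or_insert(0) += g;
    }
    assert_eq!(by_value.len(), 1);

    // Keyed by (value, index): each parameter keeps its own gradient.
    let mut by_value_and_index: HashMap<(i32, Option<usize>), i32> = HashMap::new();
    for (i, (p, g)) in params.iter().zip(grads).enumerate() {
        *by_value_and_index.entry((*p, Some(i))).or_insert(0) += g;
    }
    assert_eq!(by_value_and_index.len(), 2);
}
```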
@@ -268,6 +276,8 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
 mod tests {
     use ordered_float::NotNan;

+    use crate::loss::{l2_loss_2, predict_line_2};
+
     use super::*;

     fn extract_scalar<'a, A>(d: &'a DifferentiableHidden<A>) -> &'a A {
@@ -283,15 +293,17 @@ mod tests {
             vec![
                 DifferentiableHidden::Scalar(Scalar::Number(
                     NotNan::new(3.0).expect("3 is not NaN"),
+                    Some(0usize),
                 )),
                 DifferentiableHidden::Scalar(Scalar::Number(
                     NotNan::new(4.0).expect("4 is not NaN"),
+                    Some(1usize),
                 )),
             ]
             .into(),
         );
-        let mapped = v.map(&|x: Scalar<NotNan<f64>>| match x {
-            Scalar::Number(i) => Scalar::Number(i + NotNan::new(1.0).expect("1 is not NaN")),
+        let mapped = v.map(&mut |x: Scalar<NotNan<f64>>| match x {
+            Scalar::Number(i, n) => Scalar::Number(i + NotNan::new(1.0).expect("1 is not NaN"), n),
             Scalar::Dual(_, _) => panic!("Not hit"),
         });

@@ -305,4 +317,29 @@ mod tests {

         assert_eq!(v, [4.0, 5.0]);
     }
+
+    #[test]
+    fn test_autodiff() {
+        let input_vec = of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()]);
+        let xs = [2.0, 1.0, 4.0, 3.0].map(|x| NotNan::new(x).expect("not nan"));
+        let ys = [1.8, 1.2, 4.2, 3.3].map(|x| NotNan::new(x).expect("not nan"));
+        let grad = Differentiable::grad(
+            |x| {
+                Differentiable::of_vector(vec![of_scalar(l2_loss_2(
+                    predict_line_2,
+                    of_slice(&xs),
+                    of_slice(&ys),
+                    x,
+                ))])
+            },
+            input_vec,
+        );
+
+        let grad_vec: Vec<f64> = Differentiable::to_vector(grad)
+            .into_iter()
+            .map(to_scalar)
+            .map(|x| f64::from(*x.real_part()))
+            .collect();
+        assert_eq!(grad_vec, vec![-63.0, -21.0]);
+    }
 }
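Where the expected values in `test_autodiff` come from: `predict_line_2` predicts theta0 * x + theta1 and the loss is the sum of squared differences, so at theta = (0, 0) the partial derivatives reduce to -2 * sum(x_i * y_i) = -63 and -2 * sum(y_i) = -21. A standalone arithmetic check (not part of the test suite):

```rust
fn main() {
    let xs = [2.0, 1.0, 4.0, 3.0];
    let ys = [1.8, 1.2, 4.2, 3.3];

    // d(loss)/d(theta0) at (0, 0): -2 * sum(x_i * y_i)
    let d_theta0: f64 = -2.0 * xs.iter().zip(ys).map(|(x, y)| x * y).sum::<f64>();
    // d(loss)/d(theta1) at (0, 0): -2 * sum(y_i)
    let d_theta1: f64 = -2.0 * ys.iter().sum::<f64>();

    assert!((d_theta0 + 63.0).abs() < 1e-9);
    assert!((d_theta1 + 21.0).abs() < 1e-9);
}
```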
@@ -3,6 +3,7 @@

 pub mod auto_diff;
 pub mod expr_syntax_tree;
+pub mod loss;
 pub mod scalar;
 pub mod tensor;
 pub mod traits;
little_learner/src/loss.rs (new file, 93 lines)
@@ -0,0 +1,93 @@
+use std::{
+    iter::Sum,
+    ops::{Add, Mul, Neg},
+};
+
+use crate::{
+    auto_diff::{of_scalar, to_scalar, Differentiable},
+    scalar::Scalar,
+    traits::{One, Zero},
+};
+
+pub fn square<A>(x: &A) -> A
+where
+    A: Mul<Output = A> + Clone,
+{
+    x.clone() * x.clone()
+}
+
+pub fn dot_2<A, const RANK: usize>(
+    x: &Differentiable<A, RANK>,
+    y: &Differentiable<A, RANK>,
+) -> Differentiable<A, RANK>
+where
+    A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default,
+{
+    Differentiable::map2(x, y, &|x, y| x.clone() * y.clone())
+}
+
+fn squared_2<A, const RANK: usize>(x: &Differentiable<A, RANK>) -> Differentiable<A, RANK>
+where
+    A: Mul<Output = A> + Copy + Default,
+{
+    Differentiable::map2(x, x, &|x, y| x.clone() * y.clone())
+}
+
+fn sum_2<A>(x: Differentiable<A, 1>) -> Scalar<A>
+where
+    A: Sum<A> + Copy + Add<Output = A> + Zero,
+{
+    Differentiable::to_vector(x)
+        .into_iter()
+        .map(to_scalar)
+        .sum()
+}
+
+fn l2_norm_2<A>(prediction: &Differentiable<A, 1>, data: &Differentiable<A, 1>) -> Scalar<A>
+where
+    A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
+{
+    let diff = Differentiable::map2(prediction, data, &|x, y| x.clone() - y.clone());
+    sum_2(squared_2(&diff))
+}
+
+pub fn l2_loss_2<A, F, Params>(
+    target: F,
+    data_xs: Differentiable<A, 1>,
+    data_ys: Differentiable<A, 1>,
+    params: Params,
+) -> Scalar<A>
+where
+    F: Fn(Differentiable<A, 1>, Params) -> Differentiable<A, 1>,
+    A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero,
+{
+    let pred_ys = target(data_xs, params);
+    l2_norm_2(&pred_ys, &data_ys)
+}
+
+pub fn predict_line_2<A>(
+    xs: Differentiable<A, 1>,
+    theta: Differentiable<A, 1>,
+) -> Differentiable<A, 1>
+where
+    A: Mul<Output = A> + Add<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + One + Zero,
+{
+    let xs = Differentiable::to_vector(xs)
+        .into_iter()
+        .map(|v| to_scalar(v));
+    let mut result = vec![];
+    for x in xs {
+        let left_arg = Differentiable::of_vector(vec![
+            of_scalar(x.clone()),
+            of_scalar(<Scalar<A> as One>::one()),
+        ]);
+        let dotted = of_scalar(
+            Differentiable::to_vector(dot_2(&left_arg, &theta))
+                .iter()
+                .map(|x| to_scalar((*x).clone()))
+                .sum(),
+        );
+        result.push(dotted);
+    }
+    Differentiable::of_vector(result)
+}
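Taken together, `predict_line_2` dots each input `[x, 1]` against `theta = [theta0, theta1]` (so it predicts theta0 * x + theta1), and `l2_loss_2` sums the squared differences between those predictions and the observed ys. A small usage sketch through the new module path, with the same coefficients that appear in the `main.rs` calls further down:

```rust
use little_learner::auto_diff::of_slice;
use little_learner::loss::{l2_loss_2, predict_line_2};

fn main() {
    let xs = [2.0, 1.0, 4.0, 3.0];
    let ys = [1.8, 1.2, 4.2, 3.3];

    // Loss of the line y = 0.0099 * x + 0.0 against the four (x, y) samples.
    let loss = l2_loss_2(
        predict_line_2,
        of_slice(&xs),
        of_slice(&ys),
        of_slice(&[0.0099, 0.0]),
    );
    println!("{}", loss);
}
```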
@@ -7,7 +7,7 @@ use std::{
     ops::{Add, AddAssign, Div, Mul, Neg, Sub},
 };

-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum LinkData<A> {
     Addition(Box<Scalar<A>>, Box<Scalar<A>>),
     Neg(Box<Scalar<A>>),
@@ -16,9 +16,9 @@ pub enum LinkData<A> {
     Log(Box<Scalar<A>>),
 }

-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum Link<A> {
-    EndOfLink,
+    EndOfLink(Option<usize>),
     Link(LinkData<A>),
 }

@@ -28,7 +28,8 @@ where
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Link::EndOfLink => f.write_str("<end>"),
+            Link::EndOfLink(Some(i)) => f.write_fmt(format_args!("<end {}>", *i)),
+            Link::EndOfLink(None) => f.write_str("<end>"),
             Link::Link(LinkData::Addition(left, right)) => {
                 f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref()))
             }
@@ -59,7 +60,7 @@ impl<A> Link<A> {
         + One,
     {
         match self {
-            Link::EndOfLink => match acc.entry(d.clone()) {
+            Link::EndOfLink(_) => match acc.entry(d.clone()) {
                 Entry::Occupied(mut o) => {
                     let entry = o.get_mut();
                     *entry += z;
@@ -113,9 +114,9 @@ impl<A> Link<A> {
     }
 }

-#[derive(Clone, Hash, PartialEq, Eq)]
+#[derive(Clone, Hash, PartialEq, Eq, Debug)]
 pub enum Scalar<A> {
-    Number(A),
+    Number(A, Option<usize>),
     // The value, and the link.
     Dual(A, Link<A>),
 }
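Because `Hash`, `PartialEq`, and `Eq` are derived, the new `Option<usize>` index participates in a scalar's identity: two `Number`s with equal values but different indices are now distinct values (and distinct map keys). A stand-in sketch with the same derives, not the crate's type:

```rust
#[derive(Clone, Hash, PartialEq, Eq, Debug)]
enum ScalarLike {
    Number(i64, Option<usize>),
}

fn main() {
    // Same value, different indices: no longer equal, so they no longer collide.
    assert_ne!(ScalarLike::Number(0, Some(0)), ScalarLike::Number(0, Some(1)));
    assert_eq!(ScalarLike::Number(0, None), ScalarLike::Number(0, None));
}
```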
@@ -125,7 +126,7 @@ where
     A: Zero,
 {
     fn zero() -> Self {
-        Scalar::Number(A::zero())
+        Scalar::Number(A::zero(), None)
     }
 }

@@ -198,7 +199,7 @@ where
 impl<A> Scalar<A> {
     pub fn real_part(&self) -> &A {
         match self {
-            Scalar::Number(a) => a,
+            Scalar::Number(a, _) => a,
             Scalar::Dual(a, _) => a,
         }
     }
@@ -208,7 +209,7 @@ impl<A> Scalar<A> {
         A: Clone,
     {
         match self {
-            Scalar::Number(a) => (*a).clone(),
+            Scalar::Number(a, _) => (*a).clone(),
             Scalar::Dual(a, _) => (*a).clone(),
         }
     }
@@ -216,7 +217,7 @@ impl<A> Scalar<A> {
     pub fn link(self) -> Link<A> {
         match self {
             Scalar::Dual(_, link) => link,
-            Scalar::Number(_) => Link::EndOfLink,
+            Scalar::Number(_, i) => Link::EndOfLink(i),
         }
     }

@@ -226,15 +227,15 @@ impl<A> Scalar<A> {
     {
         match self {
             Scalar::Dual(_, data) => data.clone(),
-            Scalar::Number(_) => Link::EndOfLink,
+            Scalar::Number(_, i) => Link::EndOfLink(*i),
         }
     }

-    pub fn truncate_dual(self) -> Scalar<A>
+    pub fn truncate_dual(self, index: usize) -> Scalar<A>
     where
         A: Clone,
     {
-        Scalar::Dual(self.clone_real_part(), Link::EndOfLink)
+        Scalar::Dual(self.clone_real_part(), Link::EndOfLink(Some(index)))
     }
 }

@@ -244,8 +245,9 @@ where
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
-            Scalar::Number(n) => f.write_fmt(format_args!("{}", n)),
-            Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)),
+            Scalar::Number(n, Some(index)) => f.write_fmt(format_args!("{}_{}", n, index)),
+            Scalar::Number(n, None) => f.write_fmt(format_args!("{}", n)),
+            Scalar::Dual(n, link) => f.write_fmt(format_args!("<{}, link: {}>", n, link)),
         }
     }
 }
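With the indexed representation, `truncate_dual` stamps the parameter's position into the chain terminator, and the new `Display` arm renders an indexed number as `value_index`. A stand-in sketch of just the formatting behaviour, not the crate's type:

```rust
use std::fmt::{self, Display};

// Mirrors the new Display arms for Scalar::Number: an indexed number prints as
// "value_index", an unindexed one prints as the bare value.
enum NumberLike {
    Number(f64, Option<usize>),
}

impl Display for NumberLike {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NumberLike::Number(n, Some(index)) => write!(f, "{}_{}", n, index),
            NumberLike::Number(n, None) => write!(f, "{}", n),
        }
    }
}

fn main() {
    assert_eq!(NumberLike::Number(3.0, Some(0)).to_string(), "3_0");
    assert_eq!(NumberLike::Number(3.0, None).to_string(), "3");
}
```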
@@ -3,110 +3,53 @@

 mod with_tensor;

-use little_learner::auto_diff::{of_scalar, of_slice, to_scalar, Differentiable};
-use little_learner::scalar::Scalar;
-use little_learner::traits::{One, Zero};
-use ordered_float::NotNan;
+use little_learner::auto_diff::{of_scalar, of_slice, Differentiable};

-use std::iter::Sum;
-use std::ops::{Add, Mul, Neg};
+use little_learner::loss::{l2_loss_2, predict_line_2, square};
+use little_learner::traits::Zero;
+use ordered_float::NotNan;

 use crate::with_tensor::{l2_loss, predict_line};

-fn dot_2<A, const RANK: usize>(
-    x: &Differentiable<A, RANK>,
-    y: &Differentiable<A, RANK>,
-) -> Differentiable<A, RANK>
-where
-    A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default,
-{
-    Differentiable::map2(x, y, &|x, y| x.clone() * y.clone())
-}
-
-fn squared_2<A, const RANK: usize>(x: &Differentiable<A, RANK>) -> Differentiable<A, RANK>
-where
-    A: Mul<Output = A> + Copy + Default,
-{
-    Differentiable::map2(x, x, &|x, y| x.clone() * y.clone())
-}
-
-fn sum_2<A>(x: Differentiable<A, 1>) -> Scalar<A>
-where
-    A: Sum<A> + Copy + Add<Output = A> + Zero,
-{
-    Differentiable::to_vector(x)
-        .into_iter()
-        .map(to_scalar)
-        .sum()
-}
-
-fn l2_norm_2<A>(prediction: &Differentiable<A, 1>, data: &Differentiable<A, 1>) -> Scalar<A>
-where
-    A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
-{
-    let diff = Differentiable::map2(prediction, data, &|x, y| x.clone() - y.clone());
-    sum_2(squared_2(&diff))
-}
-
-pub fn l2_loss_2<A, F, Params>(
-    target: F,
-    data_xs: Differentiable<A, 1>,
-    data_ys: Differentiable<A, 1>,
-    params: Params,
-) -> Scalar<A>
-where
-    F: Fn(Differentiable<A, 1>, Params) -> Differentiable<A, 1>,
-    A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero,
-{
-    let pred_ys = target(data_xs, params);
-    l2_norm_2(&pred_ys, &data_ys)
-}
-
-fn predict_line_2<A>(xs: Differentiable<A, 1>, theta: Differentiable<A, 1>) -> Differentiable<A, 1>
-where
-    A: Mul<Output = A> + Add<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + One + Zero,
-{
-    let xs = Differentiable::to_vector(xs)
-        .into_iter()
-        .map(|v| to_scalar(v));
-    let mut result = vec![];
-    for x in xs {
-        let left_arg = Differentiable::of_vector(vec![
-            of_scalar(x.clone()),
-            of_scalar(<Scalar<A> as One>::one()),
-        ]);
-        let dotted = Differentiable::to_vector(dot_2(&left_arg, &theta));
-        result.push(dotted[0].clone());
-    }
-    Differentiable::of_vector(result)
-}
-
-fn square<A>(x: &A) -> A
-where
-    A: Mul<Output = A> + Clone,
-{
-    x.clone() * x.clone()
+#[allow(dead_code)]
+fn l2_loss_non_autodiff_example() {
+    let xs = [2.0, 1.0, 4.0, 3.0];
+    let ys = [1.8, 1.2, 4.2, 3.3];
+    let loss = l2_loss(predict_line, &xs, &ys, &[0.0099, 0.0]);
+    println!("{:?}", loss);
 }

 fn main() {
-    let loss = l2_loss(
-        predict_line,
-        &[2.0, 1.0, 4.0, 3.0],
-        &[1.8, 1.2, 4.2, 3.3],
-        &[0.0099, 0.0],
-    );
-    println!("{:?}", loss);
+    let input_vec = of_slice(&[NotNan::new(27.0).expect("not nan")]);
+
+    let grad = Differentiable::grad(|x| Differentiable::map(x, &mut |x| square(&x)), input_vec);
+    println!("Gradient of the x^2 function at x=27: {}", grad);
+
+    let xs = [2.0, 1.0, 4.0, 3.0];
+    let ys = [1.8, 1.2, 4.2, 3.3];

     let loss = l2_loss_2(
         predict_line_2,
-        of_slice(&[2.0, 1.0, 4.0, 3.0]),
-        of_slice(&[1.8, 1.2, 4.2, 3.3]),
-        of_slice(&[0.0099, 0.0]),
+        of_slice(&xs),
+        of_slice(&ys),
+        of_slice(&[0.0, 0.0]),
     );
-    println!("{}", loss);
+    println!("Computation of L2 loss: {}", loss);

-    let input_vec = of_slice(&[NotNan::new(27.0).expect("not nan")]);
+    let input_vec = of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()]);
+    let xs = [2.0, 1.0, 4.0, 3.0].map(|x| NotNan::new(x).expect("not nan"));
+    let ys = [1.8, 1.2, 4.2, 3.3].map(|x| NotNan::new(x).expect("not nan"));
+    let grad = Differentiable::grad(
+        |x| {
+            Differentiable::of_vector(vec![of_scalar(l2_loss_2(
+                predict_line_2,
+                of_slice(&xs),
+                of_slice(&ys),
+                x,
+            ))])
+        },
+        input_vec,
+    );

-    let grad = Differentiable::grad(|x| Differentiable::map(x, &|x| square(&x)), input_vec);
     println!("{}", grad);
 }