Add ext form of relu (#27)

Patrick Stevens
2023-06-17 15:46:19 +01:00
committed by GitHub
parent 242f71fa75
commit 5bb1bddf83
3 changed files with 170 additions and 26 deletions

View File

@@ -3,7 +3,7 @@ use crate::loss::dot;
 use crate::scalar::Scalar;
 use crate::traits::{NumLike, Zero};
-fn rectify<A>(x: A) -> A
+pub(crate) fn rectify<A>(x: A) -> A
 where
     A: Zero + PartialOrd,
 {

View File

@@ -1,6 +1,11 @@
-use crate::auto_diff::{Differentiable, DifferentiableTagged, RankedDifferentiable};
+use crate::auto_diff::{
+    Differentiable, DifferentiableTagged, RankedDifferentiable, RankedDifferentiableTagged,
+};
+use crate::decider::rectify;
+use crate::scalar::Scalar;
+use crate::traits::{NumLike, Zero};
 use std::iter::Sum;
-use std::ops::Mul;
+use std::ops::{Add, Mul};

 pub fn ext1<A, B, Tag, Tag2, F>(
     n: usize,
@@ -49,17 +54,21 @@ where
     }
 }

-pub fn elementwise_mul_via_ext<A, const RANK1: usize, const RANK2: usize>(
-    x: &RankedDifferentiable<A, RANK1>,
-    y: &RankedDifferentiable<A, RANK2>,
+pub fn elementwise_mul_via_ext<A, Tag, Tag2, const RANK1: usize, const RANK2: usize>(
+    x: &RankedDifferentiableTagged<A, Tag, RANK1>,
+    y: &RankedDifferentiableTagged<A, Tag2, RANK2>,
 ) -> RankedDifferentiable<A, RANK1>
 where
     A: Mul<Output = A> + Sum<<A as Mul>::Output> + Clone + Default,
+    Tag: Clone,
+    Tag2: Clone,
 {
     ext2(
         0,
         0,
-        &mut |x, y| Differentiable::of_scalar(x.clone().into_scalar() * y.clone().into_scalar()),
+        &mut |x, y| {
+            DifferentiableTagged::of_scalar(x.borrow_scalar().clone() * y.borrow_scalar().clone())
+        },
         x.to_unranked_borrow(),
         y.to_unranked_borrow(),
     )
@@ -67,11 +76,18 @@ where
     .unwrap()
 }

-/// Produce the matrix multiplication of the inputs, threading where necessary until the
+/// Produce the element-wise multiplication of the inputs, threading where necessary until the
 /// first argument has rank 2 and the second argument has rank 1.
-pub fn star_2_1<T>(x: &Differentiable<T>, y: &Differentiable<T>) -> Differentiable<T>
+/// This is essentially "matrix-multiply a matrix by a vector, but don't do the sum; instead
+/// leave the components to be summed in a vector".
+pub fn star_2_1<T, Tag, Tag2>(
+    x: &DifferentiableTagged<T, Tag>,
+    y: &DifferentiableTagged<T, Tag2>,
+) -> Differentiable<T>
 where
     T: Clone + Sum + Mul<Output = T> + Default,
+    Tag: Clone,
+    Tag2: Clone,
 {
     ext2(
         2,
@@ -88,16 +104,95 @@ where
     )
 }
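The doc comment on star_2_1 is easiest to read with concrete shapes. Using the values from test_dot_2_1 further down, and writing plain nested arrays as shorthand for the corresponding RankedDifferentiable values (not the real API), star_2_1 multiplies each row of the rank-2 argument element-wise by the rank-1 argument and leaves the per-row products unsummed:

    star_2_1([[2.0, 1.0, 3.1], [3.7, 4.0, 6.1]], [1.3, 0.4, 3.3])
        = [[2.6, 0.4, 10.23], [4.81, 1.6, 20.13]]

Summing each inner vector afterwards is exactly the step that dot_2_1 below adds.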
+fn sum_1_scalar<A, Tag>(x: RankedDifferentiableTagged<A, Tag, 1>) -> Scalar<A>
+where
+    A: Sum<A> + Clone + Add<Output = A> + Zero,
+{
+    RankedDifferentiableTagged::to_vector(x)
+        .into_iter()
+        .map(|x| x.to_scalar())
+        .sum()
+}
+
+pub fn sum_1<A, Tag>(x: RankedDifferentiableTagged<A, Tag, 1>) -> Differentiable<A>
+where
+    A: Sum<A> + Clone + Add<Output = A> + Zero,
+{
+    DifferentiableTagged::of_scalar(sum_1_scalar(x))
+}
+
+pub fn sum<T>(x: &Differentiable<T>) -> Differentiable<T>
+where
+    T: Sum<T> + Clone + Add<Output = T> + Zero,
+{
+    ext1(1, &mut |y| sum_1(y.clone().attach_rank::<1>().unwrap()), x)
+}
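sum is sum_1 extended via ext1 with target rank 1: a rank-1 argument collapses to a single scalar, while higher-rank arguments are threaded over their outer dimension first (this is how dot_2_1 below relies on it). Schematically, again using nested arrays as shorthand and purely illustrative values:

    sum([1.0, 2.0, 3.0])            = 6.0
    sum([[1.0, 2.0], [3.0, 4.0]])   = [3.0, 7.0]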
+
+/// Matrix-multiply W with T, threading where necessary until the first argument has rank 2 and the
+/// second argument has rank 1.
+pub fn dot_2_1<A, Tag, Tag2>(
+    w: &DifferentiableTagged<A, Tag>,
+    t: &DifferentiableTagged<A, Tag2>,
+) -> Differentiable<A>
+where
+    A: NumLike + Default,
+    Tag: Clone,
+    Tag2: Clone,
+{
+    assert!(
+        w.rank() >= 2,
+        "w needed to have rank 2 or more, was {}",
+        w.rank()
+    );
+    assert!(
+        t.rank() >= 1,
+        "t needed to have rank 1 or more, was {}",
+        t.rank()
+    );
+    sum(&star_2_1(w, t))
+}
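So dot_2_1 is sum composed with star_2_1, which for a rank-2 w and a rank-1 t is the ordinary matrix-vector product. On the test_dot_2_1 data below, the two output components work out as:

    2.0*1.3 + 1.0*0.4 + 3.1*3.3 = 2.6 + 0.4 + 10.23 = 13.23
    3.7*1.3 + 4.0*0.4 + 6.1*3.3 = 4.81 + 1.6 + 20.13 = 26.54

which is the [13.23, 26.54] the test asserts.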
+
+pub fn linear<A, Tag1, Tag2, Tag3>(
+    theta0: &DifferentiableTagged<A, Tag1>,
+    theta1: &DifferentiableTagged<A, Tag2>,
+    t: &DifferentiableTagged<A, Tag3>,
+) -> DifferentiableTagged<A, ()>
+where
+    A: NumLike + Default,
+    Tag1: Clone,
+    Tag2: Clone,
+    Tag3: Clone,
+{
+    dot_2_1(theta0, t).map2_tagged(theta1, &mut |x, _, y, _| (x.clone() + y.clone(), ()))
+}
+
+pub fn relu<A, Tag1, Tag2, Tag3>(
+    t: &RankedDifferentiableTagged<A, Tag1, 1>,
+    theta0: &RankedDifferentiableTagged<A, Tag2, 2>,
+    theta1: &RankedDifferentiableTagged<A, Tag3, 1>,
+) -> Differentiable<A>
+where
+    A: NumLike + PartialOrd + Default,
+    Tag1: Clone,
+    Tag2: Clone,
+    Tag3: Clone,
+{
+    linear(
+        theta0.to_unranked_borrow(),
+        theta1.to_unranked_borrow(),
+        t.to_unranked_borrow(),
+    )
+    .map(&mut rectify)
+}
 
 #[cfg(test)]
 mod tests {
     use crate::auto_diff::{Differentiable, RankedDifferentiable};
-    use crate::ext::{elementwise_mul_via_ext, ext1, ext2, star_2_1};
+    use crate::ext::{dot_2_1, ext1, relu, star_2_1};
     use crate::not_nan::{to_not_nan_1, to_not_nan_2};
     use crate::scalar::Scalar;
     use crate::traits::Zero;
     use ordered_float::NotNan;
     use std::iter::Sum;
     use std::ops::Mul;

     fn zeros_redefined<A>(t: &Differentiable<A>) -> Differentiable<A>
     where
@@ -269,4 +364,58 @@ mod tests {
             ]
         )
     }
+
+    #[test]
+    fn test_dot_2_1() {
+        let w = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+            [2.0, 1.0, 3.1],
+            [3.7, 4.0, 6.1],
+        ]));
+        let t = RankedDifferentiable::of_slice(&to_not_nan_1([1.3, 0.4, 3.3]));
+        let result = dot_2_1(w.to_unranked_borrow(), t.to_unranked_borrow())
+            .attach_rank::<1>()
+            .unwrap()
+            .to_vector()
+            .iter()
+            .map(|x| x.clone().to_scalar().clone_real_part().into_inner())
+            .collect::<Vec<_>>();
+        assert_eq!(result, [13.23, 26.54])
+    }
+
+    #[test]
+    fn test_relu() {
+        let weights = to_not_nan_2([
+            [7.1, 4.3, -6.4],
+            [1.0, 2.0, 3.0],
+            [4.0, 5.0, 6.0],
+            [-1.3, -2.4, -3.6],
+        ]);
+        let biases = to_not_nan_1([10.2, 11.3, 12.4, 13.5]);
+        let inputs = to_not_nan_1([7.0, 8.0, 9.0]);
+        let theta0 = RankedDifferentiable::of_slice_2::<_, 2>(&weights);
+        let theta1 = RankedDifferentiable::of_slice(&biases);
+        let t = RankedDifferentiable::of_slice(&inputs);
+        let result = relu(&t, &theta0, &theta1)
+            .into_vector()
+            .iter()
+            .map(|x| x.borrow_scalar().clone_real_part().into_inner())
+            .collect::<Vec<_>>();
+
+        let mut expected = Vec::new();
+        for (weights, bias) in weights.iter().zip(biases.iter()) {
+            expected.push(
+                crate::decider::relu(
+                    &t,
+                    &RankedDifferentiable::of_slice(weights),
+                    Scalar::make(bias.clone()),
+                )
+                .clone_real_part()
+                .into_inner(),
+            );
+        }
+        assert_eq!(result, expected);
+    }
 }

View File

@@ -4,6 +4,7 @@ use std::{
 };
 use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
+use crate::ext::{sum, sum_1};
 use crate::{
     auto_diff::{DifferentiableTagged, RankedDifferentiable},
     scalar::Scalar,
@@ -76,16 +77,6 @@ where
     RankedDifferentiable::map2(x, x, &mut |x, y| x.clone() * y.clone())
 }
-
-fn sum_2<A>(x: RankedDifferentiable<A, 1>) -> Scalar<A>
-where
-    A: Sum<A> + Clone + Add<Output = A> + Zero,
-{
-    RankedDifferentiable::to_vector(x)
-        .into_iter()
-        .map(|x| x.to_scalar())
-        .sum()
-}

 fn l2_norm_2<A>(
     prediction: &RankedDifferentiable<A, 1>,
     data: &RankedDifferentiable<A, 1>,
@@ -94,7 +85,7 @@ where
     A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
 {
     let diff = RankedDifferentiable::map2(prediction, data, &mut |x, y| x.clone() - y.clone());
-    sum_2(squared_2(&diff))
+    sum_1(squared_2(&diff)).into_scalar()
 }

 pub fn l2_loss_2<A, F, Params, const N: usize>(
@@ -249,12 +240,16 @@ where
             .map(|v| RankedDifferentiable::of_scalar(v.borrow_scalar().clone()))
             .collect::<Vec<_>>(),
     );
-    let theta1 = theta[1].borrow_scalar().clone();
+    let theta1 = theta[1].clone().attach_rank::<0>().unwrap();
     let dotted: Vec<_> = xs
         .to_vector()
         .into_iter()
-        .map(|point| sum_2(elementwise_mul(&theta0, &point)))
-        .map(|x| RankedDifferentiable::of_scalar(x + theta1.clone()))
+        .map(|point| {
+            sum(elementwise_mul(&theta0, &point).to_unranked_borrow())
+                .attach_rank::<0>()
+                .unwrap()
+        })
+        .map(|x| x.map2(&theta1, &mut |x, theta| x.clone() + theta.clone()))
         .collect();
     RankedDifferentiable::of_vector(dotted)
 }