Add ext form of relu (#27)
This commit is contained in:
@@ -3,7 +3,7 @@ use crate::loss::dot;
|
||||
use crate::scalar::Scalar;
|
||||
use crate::traits::{NumLike, Zero};
|
||||
|
||||
fn rectify<A>(x: A) -> A
|
||||
pub(crate) fn rectify<A>(x: A) -> A
|
||||
where
|
||||
A: Zero + PartialOrd,
|
||||
{
|
||||
|
@@ -1,6 +1,11 @@
|
||||
use crate::auto_diff::{Differentiable, DifferentiableTagged, RankedDifferentiable};
|
||||
use crate::auto_diff::{
|
||||
Differentiable, DifferentiableTagged, RankedDifferentiable, RankedDifferentiableTagged,
|
||||
};
|
||||
use crate::decider::rectify;
|
||||
use crate::scalar::Scalar;
|
||||
use crate::traits::{NumLike, Zero};
|
||||
use std::iter::Sum;
|
||||
use std::ops::Mul;
|
||||
use std::ops::{Add, Mul};
|
||||
|
||||
pub fn ext1<A, B, Tag, Tag2, F>(
|
||||
n: usize,
|
||||
@@ -49,17 +54,21 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn elementwise_mul_via_ext<A, const RANK1: usize, const RANK2: usize>(
|
||||
x: &RankedDifferentiable<A, RANK1>,
|
||||
y: &RankedDifferentiable<A, RANK2>,
|
||||
pub fn elementwise_mul_via_ext<A, Tag, Tag2, const RANK1: usize, const RANK2: usize>(
|
||||
x: &RankedDifferentiableTagged<A, Tag, RANK1>,
|
||||
y: &RankedDifferentiableTagged<A, Tag2, RANK2>,
|
||||
) -> RankedDifferentiable<A, RANK1>
|
||||
where
|
||||
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Clone + Default,
|
||||
Tag: Clone,
|
||||
Tag2: Clone,
|
||||
{
|
||||
ext2(
|
||||
0,
|
||||
0,
|
||||
&mut |x, y| Differentiable::of_scalar(x.clone().into_scalar() * y.clone().into_scalar()),
|
||||
&mut |x, y| {
|
||||
DifferentiableTagged::of_scalar(x.borrow_scalar().clone() * y.borrow_scalar().clone())
|
||||
},
|
||||
x.to_unranked_borrow(),
|
||||
y.to_unranked_borrow(),
|
||||
)
|
||||
@@ -67,11 +76,18 @@ where
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Produce the matrix multiplication of the inputs, threading where necessary until the
|
||||
/// Produce the element-wise multiplication of the inputs, threading where necessary until the
|
||||
/// first argument has rank 2 and the second argument has rank 1.
|
||||
pub fn star_2_1<T>(x: &Differentiable<T>, y: &Differentiable<T>) -> Differentiable<T>
|
||||
/// This is essentially "matrix-multiply a matrix by a vector, but don't do the sum; instead
|
||||
/// leave the components to be summed in a vector".
|
||||
pub fn star_2_1<T, Tag, Tag2>(
|
||||
x: &DifferentiableTagged<T, Tag>,
|
||||
y: &DifferentiableTagged<T, Tag2>,
|
||||
) -> Differentiable<T>
|
||||
where
|
||||
T: Clone + Sum + Mul<Output = T> + Default,
|
||||
Tag: Clone,
|
||||
Tag2: Clone,
|
||||
{
|
||||
ext2(
|
||||
2,
|
||||
@@ -88,16 +104,95 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
/// Add up the entries of a rank-1 differentiable and return the total as a bare `Scalar`.
///
/// `x` is taken by value: it is turned into a vector with `to_vector`, each element is
/// converted with `to_scalar`, and the resulting scalars are summed with `Iterator::sum`.
/// The `Tag` parameter is unconstrained and does not appear in the return type.
fn sum_1_scalar<A, Tag>(x: RankedDifferentiableTagged<A, Tag, 1>) -> Scalar<A>
|
||||
where
|
||||
A: Sum<A> + Clone + Add<Output = A> + Zero,
|
||||
{
|
||||
RankedDifferentiableTagged::to_vector(x)
|
||||
.into_iter()
|
||||
.map(|x| x.to_scalar())
|
||||
.sum()
|
||||
}
|
||||
|
||||
/// Sum the entries of a rank-1 differentiable, returning the total wrapped as a
/// `Differentiable`.
///
/// This is `sum_1_scalar` followed by `DifferentiableTagged::of_scalar`; `x` is consumed.
pub fn sum_1<A, Tag>(x: RankedDifferentiableTagged<A, Tag, 1>) -> Differentiable<A>
|
||||
where
|
||||
A: Sum<A> + Clone + Add<Output = A> + Zero,
|
||||
{
|
||||
DifferentiableTagged::of_scalar(sum_1_scalar(x))
|
||||
}
|
||||
|
||||
/// Apply `sum_1` to `x` through `ext1` with a target rank argument of 1.
///
/// Presumably `ext1` descends into `x` until it reaches rank-1 pieces and sums each of
/// them; confirm against `ext1`, whose body is not shown here.
///
/// # Panics
///
/// The closure clones the value handed to it by `ext1` and unwraps
/// `attach_rank::<1>()`, so it panics if that value cannot be given rank 1.
pub fn sum<T>(x: &Differentiable<T>) -> Differentiable<T>
|
||||
where
|
||||
T: Sum<T> + Clone + Add<Output = T> + Zero,
|
||||
{
|
||||
ext1(1, &mut |y| sum_1(y.clone().attach_rank::<1>().unwrap()), x)
|
||||
}
|
||||
|
||||
/// Matrix-multiply W with T, threading where necessary until the first argument has rank 2 and the
|
||||
/// second argument has rank 1.
///
/// Implemented as `sum(&star_2_1(w, t))`: `star_2_1` forms the products and `sum` then
/// adds them up. The tags of both inputs may differ; the result is an untagged
/// `Differentiable`.
///
/// # Panics
///
/// Panics if `w.rank() < 2` or if `t.rank() < 1`.
|
||||
pub fn dot_2_1<A, Tag, Tag2>(
|
||||
w: &DifferentiableTagged<A, Tag>,
|
||||
t: &DifferentiableTagged<A, Tag2>,
|
||||
) -> Differentiable<A>
|
||||
where
|
||||
A: NumLike + Default,
|
||||
Tag: Clone,
|
||||
Tag2: Clone,
|
||||
{
|
||||
// Reject inputs whose rank is too low before threading begins, with a message
// that reports the offending rank.
assert!(
|
||||
w.rank() >= 2,
|
||||
"w needed to have rank 2 or more, was {}",
|
||||
w.rank()
|
||||
);
|
||||
assert!(
|
||||
t.rank() >= 1,
|
||||
"t needed to have rank 1 or more, was {}",
|
||||
t.rank()
|
||||
);
|
||||
sum(&star_2_1(w, t))
|
||||
}
|
||||
|
||||
/// Compute `dot_2_1(theta0, t)` and add `theta1` to it.
///
/// The addition is done with `map2_tagged`, whose closure adds the two scalars it is
/// given and ignores both tags, so the result carries the unit tag `()`.
/// Presumably `map2_tagged` pairs entries position by position and requires matching
/// shapes; confirm against its definition.
///
/// # Panics
///
/// Panics under the same conditions as `dot_2_1`: `theta0` must have rank at least 2
/// and `t` rank at least 1.
pub fn linear<A, Tag1, Tag2, Tag3>(
|
||||
theta0: &DifferentiableTagged<A, Tag1>,
|
||||
theta1: &DifferentiableTagged<A, Tag2>,
|
||||
t: &DifferentiableTagged<A, Tag3>,
|
||||
) -> DifferentiableTagged<A, ()>
|
||||
where
|
||||
A: NumLike + Default,
|
||||
Tag1: Clone,
|
||||
Tag2: Clone,
|
||||
Tag3: Clone,
|
||||
{
|
||||
dot_2_1(theta0, t).map2_tagged(theta1, &mut |x, _, y, _| (x.clone() + y.clone(), ()))
|
||||
}
|
||||
|
||||
/// Apply `rectify` to the result of `linear(theta0, theta1, t)`.
///
/// Note the argument order: this function takes `t` first, whereas `linear` takes it
/// last. `t` and `theta1` are rank 1 and `theta0` is rank 2; in this file's
/// `test_relu` they are the inputs, the biases and the weight matrix respectively.
/// The rectification is applied through `map`, presumably once per scalar entry.
pub fn relu<A, Tag1, Tag2, Tag3>(
|
||||
t: &RankedDifferentiableTagged<A, Tag1, 1>,
|
||||
theta0: &RankedDifferentiableTagged<A, Tag2, 2>,
|
||||
theta1: &RankedDifferentiableTagged<A, Tag3, 1>,
|
||||
) -> Differentiable<A>
|
||||
where
|
||||
A: NumLike + PartialOrd + Default,
|
||||
Tag1: Clone,
|
||||
Tag2: Clone,
|
||||
Tag3: Clone,
|
||||
{
|
||||
linear(
|
||||
theta0.to_unranked_borrow(),
|
||||
theta1.to_unranked_borrow(),
|
||||
t.to_unranked_borrow(),
|
||||
)
|
||||
.map(&mut rectify)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiable};
|
||||
use crate::ext::{elementwise_mul_via_ext, ext1, ext2, star_2_1};
|
||||
use crate::ext::{dot_2_1, ext1, relu, star_2_1};
|
||||
use crate::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use crate::scalar::Scalar;
|
||||
use crate::traits::Zero;
|
||||
use ordered_float::NotNan;
|
||||
use std::iter::Sum;
|
||||
use std::ops::Mul;
|
||||
|
||||
fn zeros_redefined<A>(t: &Differentiable<A>) -> Differentiable<A>
|
||||
where
|
||||
@@ -269,4 +364,58 @@ mod tests {
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
// Checks `dot_2_1` on a 2x3 matrix `w` and a length-3 vector `t`: the expected entries
// are the dot products of each row of `w` with `t`.
#[test]
|
||||
fn test_dot_2_1() {
|
||||
let w = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
|
||||
[2.0, 1.0, 3.1],
|
||||
[3.7, 4.0, 6.1],
|
||||
]));
|
||||
let t = RankedDifferentiable::of_slice(&to_not_nan_1([1.3, 0.4, 3.3]));
|
||||
|
||||
// Unwrap the rank-1 result down to plain floats so it can be compared directly.
let result = dot_2_1(w.to_unranked_borrow(), t.to_unranked_borrow())
|
||||
.attach_rank::<1>()
|
||||
.unwrap()
|
||||
.to_vector()
|
||||
.iter()
|
||||
.map(|x| x.clone().to_scalar().clone_real_part().into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
// NOTE(review): this is an exact floating-point comparison, so it depends on the
// order in which the products are accumulated.
assert_eq!(result, [13.23, 26.54])
|
||||
}
|
||||
|
||||
// Checks that the ext-based `relu`, called once over the whole weight matrix, agrees
// with `crate::decider::relu` applied to one weight row and one bias at a time.
#[test]
|
||||
fn test_relu() {
|
||||
let weights = to_not_nan_2([
|
||||
[7.1, 4.3, -6.4],
|
||||
[1.0, 2.0, 3.0],
|
||||
[4.0, 5.0, 6.0],
|
||||
[-1.3, -2.4, -3.6],
|
||||
]);
|
||||
let biases = to_not_nan_1([10.2, 11.3, 12.4, 13.5]);
|
||||
let inputs = to_not_nan_1([7.0, 8.0, 9.0]);
|
||||
let theta0 = RankedDifferentiable::of_slice_2::<_, 2>(&weights);
|
||||
let theta1 = RankedDifferentiable::of_slice(&biases);
|
||||
let t = RankedDifferentiable::of_slice(&inputs);
|
||||
|
||||
// Value under test: one call with the full weight matrix and bias vector.
let result = relu(&t, &theta0, &theta1)
|
||||
.into_vector()
|
||||
.iter()
|
||||
.map(|x| x.borrow_scalar().clone_real_part().into_inner())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Reference values: the decider version of relu, evaluated per (row, bias) pair.
let mut expected = Vec::new();
|
||||
for (weights, bias) in weights.iter().zip(biases.iter()) {
|
||||
expected.push(
|
||||
crate::decider::relu(
|
||||
&t,
|
||||
&RankedDifferentiable::of_slice(weights),
|
||||
Scalar::make(bias.clone()),
|
||||
)
|
||||
.clone_real_part()
|
||||
.into_inner(),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(result, expected);
|
||||
}
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ use std::{
|
||||
};
|
||||
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
|
||||
use crate::ext::{sum, sum_1};
|
||||
use crate::{
|
||||
auto_diff::{DifferentiableTagged, RankedDifferentiable},
|
||||
scalar::Scalar,
|
||||
@@ -76,16 +77,6 @@ where
|
||||
RankedDifferentiable::map2(x, x, &mut |x, y| x.clone() * y.clone())
|
||||
}
|
||||
|
||||
/// Add up the entries of a rank-1 differentiable and return the total as a `Scalar`.
///
/// `x` is consumed: it is turned into a vector with `to_vector`, each element is
/// converted with `to_scalar`, and the scalars are summed with `Iterator::sum`.
fn sum_2<A>(x: RankedDifferentiable<A, 1>) -> Scalar<A>
|
||||
where
|
||||
A: Sum<A> + Clone + Add<Output = A> + Zero,
|
||||
{
|
||||
RankedDifferentiable::to_vector(x)
|
||||
.into_iter()
|
||||
.map(|x| x.to_scalar())
|
||||
.sum()
|
||||
}
|
||||
|
||||
fn l2_norm_2<A>(
|
||||
prediction: &RankedDifferentiable<A, 1>,
|
||||
data: &RankedDifferentiable<A, 1>,
|
||||
@@ -94,7 +85,7 @@ where
|
||||
A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
|
||||
{
|
||||
let diff = RankedDifferentiable::map2(prediction, data, &mut |x, y| x.clone() - y.clone());
|
||||
sum_2(squared_2(&diff))
|
||||
sum_1(squared_2(&diff)).into_scalar()
|
||||
}
|
||||
|
||||
pub fn l2_loss_2<A, F, Params, const N: usize>(
|
||||
@@ -249,12 +240,16 @@ where
|
||||
.map(|v| RankedDifferentiable::of_scalar(v.borrow_scalar().clone()))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
let theta1 = theta[1].borrow_scalar().clone();
|
||||
let theta1 = theta[1].clone().attach_rank::<0>().unwrap();
|
||||
let dotted: Vec<_> = xs
|
||||
.to_vector()
|
||||
.into_iter()
|
||||
.map(|point| sum_2(elementwise_mul(&theta0, &point)))
|
||||
.map(|x| RankedDifferentiable::of_scalar(x + theta1.clone()))
|
||||
.map(|point| {
|
||||
sum(elementwise_mul(&theta0, &point).to_unranked_borrow())
|
||||
.attach_rank::<0>()
|
||||
.unwrap()
|
||||
})
|
||||
.map(|x| x.map2(&theta1, &mut |x, theta| x.clone() + theta.clone()))
|
||||
.collect();
|
||||
RankedDifferentiable::of_vector(dotted)
|
||||
}
|
||||
|
Reference in New Issue
Block a user