Baby's first layer (#25)

This commit is contained in:
Patrick Stevens
2023-06-14 14:16:56 +01:00
committed by GitHub
parent 1ee76d4bc3
commit 6ab19d4c4d
5 changed files with 220 additions and 12 deletions

View File

@@ -69,7 +69,7 @@ where
#[derive(Debug)]
enum DifferentiableContents<A, Tag> {
Scalar(Scalar<A>, Tag),
// Contains the rank.
// Contains the rank of this differentiable (i.e. one more than the rank of the inputs).
Vector(Vec<DifferentiableTagged<A, Tag>>, usize),
}
@@ -199,6 +199,64 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
}
}
/// Unwraps one layer of each input, so the passed function takes inputs which have decreased
/// the ranks of the `map2_once_tagged` input by one.
/// Panics if passed a scalar or if the input vectors are not the same length.
pub fn map2_once_tagged<B, C, Tag2, Tag3, F>(
self: &DifferentiableContents<A, Tag>,
other: &DifferentiableContents<B, Tag2>,
mut f: F,
) -> DifferentiableContents<C, Tag3>
where
F: FnMut(
&DifferentiableTagged<A, Tag>,
&DifferentiableTagged<B, Tag2>,
) -> DifferentiableTagged<C, Tag3>,
{
match (self, other) {
(DifferentiableContents::Scalar(_, _), _) => {
panic!("First arg needed to have non-scalar rank")
}
(_, DifferentiableContents::Scalar(_, _)) => {
panic!("Second arg needed to have non-scalar rank")
}
(
DifferentiableContents::Vector(v1, rank1),
DifferentiableContents::Vector(v2, _rank2),
) => {
assert_eq!(
v1.len(),
v2.len(),
"Must map two vectors of the same length, got {rank1} and {_rank2}"
);
assert_ne!(
v1.len(),
0,
"Cannot determine a rank of a zero-length vector"
);
let mut rank = 0usize;
DifferentiableContents::Vector(
v1.iter()
.zip(v2.iter())
.map(|(a, b)| {
let result = f(a, b);
match result.contents {
DifferentiableContents::Vector(_, discovered_rank) => {
rank = discovered_rank + 1;
}
DifferentiableContents::Scalar(_, _) => {
rank = 1;
}
}
result
})
.collect(),
rank,
)
}
}
}
fn of_slice<'a, T, I>(tag: Tag, input: I) -> DifferentiableContents<T, Tag>
where
T: Clone + 'a,
@@ -277,6 +335,22 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
}
}
pub fn map2_once_tagged<B, C, Tag2, Tag3, F>(
self: &DifferentiableTagged<A, Tag>,
other: &DifferentiableTagged<B, Tag2>,
f: F,
) -> DifferentiableTagged<C, Tag3>
where
F: FnMut(
&DifferentiableTagged<A, Tag>,
&DifferentiableTagged<B, Tag2>,
) -> DifferentiableTagged<C, Tag3>,
{
DifferentiableTagged {
contents: self.contents.map2_once_tagged(&other.contents, f),
}
}
pub fn attach_rank<const RANK: usize>(
self: DifferentiableTagged<A, Tag>,
) -> Option<RankedDifferentiableTagged<A, Tag, RANK>> {
@@ -582,10 +656,10 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
}
}
pub fn map2_tagged<B, C, Tag2, Tag3, F>(
self: &RankedDifferentiableTagged<A, Tag, RANK>,
other: &RankedDifferentiableTagged<B, Tag2, RANK>,
f: &mut F,
pub fn map2_tagged<'a, 'b, B, C, Tag2, Tag3, F>(
self: &'a RankedDifferentiableTagged<A, Tag, RANK>,
other: &'a RankedDifferentiableTagged<B, Tag2, RANK>,
f: &'b mut F,
) -> RankedDifferentiableTagged<C, Tag3, RANK>
where
F: FnMut(&Scalar<A>, Tag, &Scalar<B>, Tag2) -> (Scalar<C>, Tag3),
@@ -598,6 +672,44 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
contents: DifferentiableTagged::map2_tagged(&self.contents, &other.contents, f),
}
}
pub fn map2_once_tagged<
'a,
'c,
B,
C: 'a,
Tag2,
Tag3: 'a,
F,
const RANK_B: usize,
const RANK_OUT: usize,
>(
self: &'a RankedDifferentiableTagged<A, Tag, RANK>,
other: &'a RankedDifferentiableTagged<B, Tag2, RANK_B>,
f: &'c mut F,
) -> RankedDifferentiableTagged<C, Tag3, RANK_OUT>
where
F: FnMut(
&RankedDifferentiableTagged<A, Tag, { RANK - 1 }>,
&RankedDifferentiableTagged<B, Tag2, { RANK_B - 1 }>,
) -> RankedDifferentiableTagged<C, Tag3, { RANK_OUT - 1 }>,
A: Clone,
B: Clone,
Tag: Clone,
Tag2: Clone,
'c: 'a,
{
RankedDifferentiableTagged {
contents: DifferentiableTagged::map2_once_tagged(
&self.contents,
&other.contents,
&mut |a: &DifferentiableTagged<A, Tag>, b: &DifferentiableTagged<B, Tag2>| {
let a = (*a).clone().attach_rank::<{ RANK - 1 }>().unwrap();
let b = (*b).clone().attach_rank::<{ RANK_B - 1 }>().unwrap();
f(&a, &b).to_unranked()
},
),
}
}
pub fn to_vector(
self: RankedDifferentiableTagged<A, Tag, RANK>,
@@ -634,6 +746,22 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
{
self.map2_tagged(other, &mut |a, (), b, ()| (f(a, b), ()))
}
pub fn map2_once<B, C, F, const RANK_B: usize, const RANK_OUT: usize>(
self: &RankedDifferentiable<A, RANK>,
other: &RankedDifferentiable<B, RANK_B>,
f: &mut F,
) -> RankedDifferentiable<C, RANK_OUT>
where
F: FnMut(
&RankedDifferentiable<A, { RANK - 1 }>,
&RankedDifferentiable<B, { RANK_B - 1 }>,
) -> RankedDifferentiable<C, { RANK_OUT - 1 }>,
A: Clone,
B: Clone,
{
self.map2_once_tagged(other, f)
}
}
pub fn grad<A, Tag, F, const RANK: usize, const PARAM_RANK: usize>(

View File

@@ -15,19 +15,19 @@ where
}
fn linear<A, Tag1, Tag2>(
t: RankedDifferentiableTagged<A, Tag1, 1>,
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
t: &RankedDifferentiableTagged<A, Tag1, 1>,
theta0: &RankedDifferentiableTagged<A, Tag2, 1>,
theta1: Scalar<A>,
) -> Scalar<A>
where
A: NumLike,
{
dot(&theta0, &t) + theta1
dot(theta0, t) + theta1
}
pub fn relu<A, Tag1, Tag2>(
t: RankedDifferentiableTagged<A, Tag1, 1>,
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
t: &RankedDifferentiableTagged<A, Tag1, 1>,
theta0: &RankedDifferentiableTagged<A, Tag2, 1>,
theta1: Scalar<A>,
) -> Scalar<A>
where
@@ -50,7 +50,7 @@ mod test_decider {
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
let result = linear(t, theta0, theta1).real_part().into_inner();
let result = linear(&t, &theta0, theta1).real_part().into_inner();
assert!((result + 0.1).abs() < 0.000_000_01);
}
@@ -61,7 +61,7 @@ mod test_decider {
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
let result = relu(t, theta0, theta1).real_part().into_inner();
let result = relu(&t, &theta0, theta1).real_part().into_inner();
assert_eq!(result, 0.0);
}

View File

@@ -0,0 +1,75 @@
use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
use crate::decider::relu;
use crate::traits::NumLike;
/// Returns a tensor1.
/// Theta has two components: a tensor2 of weights and a tensor1 of bias.
pub fn layer<T>(
theta: Differentiable<T>,
t: RankedDifferentiable<T, 1>,
) -> RankedDifferentiable<T, 1>
where
T: NumLike + PartialOrd,
{
let mut theta = theta.into_vector();
assert_eq!(theta.len(), 2, "Needed weights and a bias");
let b = theta.pop().unwrap().attach_rank::<1>().unwrap();
let w = theta.pop().unwrap().attach_rank::<2>().unwrap();
RankedDifferentiableTagged::map2_once(
&w,
&b,
&mut |w: &RankedDifferentiable<_, 1>, b: &RankedDifferentiable<_, 0>| {
RankedDifferentiableTagged::of_scalar(relu(&t, w, b.clone().to_scalar()))
},
)
}
#[cfg(test)]
mod tests {
use crate::auto_diff::{Differentiable, RankedDifferentiable};
use crate::layer::layer;
use crate::not_nan::{to_not_nan_1, to_not_nan_2};
#[test]
fn test_single_layer() {
let b = RankedDifferentiable::of_slice(&to_not_nan_1([1.0, 2.0]));
let w = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[3.0, 4.0, 5.0],
[6.0, 7.0, 8.0],
]));
let theta = Differentiable::of_vec(vec![w.to_unranked(), b.to_unranked()]);
/*
Two neurons:
w =
(3 4 5
6 7 8)
b = (1, 2)
Three inputs:
t = (9, 10, 11)
Output has two elements, one per neuron.
Neuron 1 has weights (3,4,5) and bias 1;
Neuron 2 has weights (6,7,8) and bias 2.
Neuron 1 is relu(t, (3,4,5), 1), which is (9, 10, 11).(3, 4, 5) + 1.
Neuron 2 is relu(t, (6,7,8), 2), which is (9, 10, 11).(6, 7, 8) + 2.
*/
let t = RankedDifferentiable::of_slice(&to_not_nan_1([9.0, 10.0, 11.0]));
let mut output = layer(theta, t)
.to_vector()
.iter()
.map(|t| (*t).clone().to_scalar().clone_real_part().into_inner())
.collect::<Vec<_>>();
assert_eq!(output.len(), 2);
let result_2 = output.pop().unwrap();
let result_1 = output.pop().unwrap();
assert_eq!(result_1, (9 * 3 + 10 * 4 + 11 * 5 + 1) as f64);
assert_eq!(result_2, (9 * 6 + 10 * 7 + 11 * 8 + 2) as f64);
}
}

View File

@@ -6,6 +6,7 @@ pub mod auto_diff;
pub mod decider;
pub mod gradient_descent;
pub mod hyper;
pub mod layer;
pub mod loss;
pub mod not_nan;
pub mod predictor;

View File

@@ -7,6 +7,10 @@ where
xs.map(|x| NotNan::new(x).expect("not nan"))
}
pub fn from_not_nan_1<T, const N: usize>(xs: [NotNan<T>; N]) -> [T; N] {
xs.map(|x| x.into_inner())
}
pub fn to_not_nan_2<T, const N: usize, const M: usize>(xs: [[T; N]; M]) -> [[NotNan<T>; N]; M]
where
T: ordered_float::Float,