* Ext1

* ext1 tests

* ext2
This commit is contained in:
Patrick Stevens
2023-06-14 15:58:26 +01:00
committed by GitHub
parent 6ab19d4c4d
commit 242f71fa75
3 changed files with 344 additions and 0 deletions

View File

@@ -199,6 +199,34 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
}
}
pub fn map_once_tagged<B, Tag2, F>(
self: &DifferentiableContents<A, Tag>,
mut f: F,
) -> DifferentiableContents<B, Tag2>
where
F: FnMut(&DifferentiableTagged<A, Tag>) -> DifferentiableTagged<B, Tag2>,
{
match self {
DifferentiableContents::Scalar(_, _) => {
panic!("can't map_once_tagged into a scalar");
}
DifferentiableContents::Vector(v, _rank) => {
assert_ne!(v.len(), 0, "Can't get rank of an empty vector");
let mut rank = 0;
DifferentiableContents::Vector(
v.iter()
.map(|x| {
let result = f(x);
rank = result.rank();
result
})
.collect(),
rank + 1,
)
}
}
}
/// Unwraps one layer of each input, so the passed function takes inputs which have decreased
/// the ranks of the `map2_once_tagged` input by one.
/// Panics if passed a scalar or if the input vectors are not the same length.
@@ -335,6 +363,18 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
}
}
pub fn map_once_tagged<F, B, Tag2>(
self: &DifferentiableTagged<A, Tag>,
f: F,
) -> DifferentiableTagged<B, Tag2>
where
F: FnMut(&DifferentiableTagged<A, Tag>) -> DifferentiableTagged<B, Tag2>,
{
DifferentiableTagged {
contents: self.contents.map_once_tagged(f),
}
}
pub fn map2_once_tagged<B, C, Tag2, Tag3, F>(
self: &DifferentiableTagged<A, Tag>,
other: &DifferentiableTagged<B, Tag2>,
@@ -519,6 +559,37 @@ pub struct RankedDifferentiableTagged<A, Tag, const RANK: usize> {
contents: DifferentiableTagged<A, Tag>,
}
impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
pub fn map_once_tagged<B, Tag2, F, const RANK2: usize>(
&self,
f: &mut F,
) -> DifferentiableTagged<B, Tag2>
where
A: Clone,
Tag: Clone,
B: Clone,
Tag2: Clone,
F: FnMut(&RankedDifferentiableTagged<A, Tag, RANK2>) -> DifferentiableTagged<B, Tag2>,
{
match &self.contents.contents {
DifferentiableContents::Scalar(_, _) => {
panic!("forbidden by the types")
}
DifferentiableContents::Vector(v, rank) => {
assert_eq!(*rank, RANK2);
DifferentiableTagged {
contents: DifferentiableContents::Vector(
v.iter()
.map(|x| f(&(*x).clone().attach_rank::<RANK2>().unwrap()))
.collect(),
RANK2 - 1,
),
}
}
}
}
}
impl<A, Tag, const RANK: usize> Display for RankedDifferentiableTagged<A, Tag, RANK>
where
A: Display,

272
little_learner/src/ext.rs Normal file
View File

@@ -0,0 +1,272 @@
use crate::auto_diff::{Differentiable, DifferentiableTagged, RankedDifferentiable};
use std::iter::Sum;
use std::ops::Mul;
pub fn ext1<A, B, Tag, Tag2, F>(
n: usize,
f: &mut F,
t: &DifferentiableTagged<A, Tag>,
) -> DifferentiableTagged<B, Tag2>
where
F: FnMut(&DifferentiableTagged<A, Tag>) -> DifferentiableTagged<B, Tag2>,
{
if t.rank() == n {
f(t)
} else {
t.map_once_tagged(|x| ext1(n, f, x))
}
}
pub fn ext2<A, B, C, Tag, Tag2, Tag3, F>(
n: usize,
m: usize,
f: &mut F,
t: &DifferentiableTagged<A, Tag>,
u: &DifferentiableTagged<B, Tag2>,
) -> DifferentiableTagged<C, Tag3>
where
F: FnMut(
&DifferentiableTagged<A, Tag>,
&DifferentiableTagged<B, Tag2>,
) -> DifferentiableTagged<C, Tag3>,
A: Clone,
Tag: Clone,
B: Clone,
Tag2: Clone,
{
if t.rank() == n && u.rank() == m {
f(t, u)
} else if t.rank() == n {
u.map_once_tagged(|eu| ext2(n, m, f, t, eu))
} else if u.rank() == m {
t.map_once_tagged(|et| ext2(n, m, f, et, u))
} else if t.rank() == u.rank() {
t.map2_once_tagged(u, |t, u| ext2(n, m, f, t, u))
} else if t.rank() > u.rank() {
t.map_once_tagged(|et| ext2(n, m, f, et, u))
} else {
u.map_once_tagged(|eu| ext2(n, m, f, t, eu))
}
}
pub fn elementwise_mul_via_ext<A, const RANK1: usize, const RANK2: usize>(
x: &RankedDifferentiable<A, RANK1>,
y: &RankedDifferentiable<A, RANK2>,
) -> RankedDifferentiable<A, RANK1>
where
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Clone + Default,
{
ext2(
0,
0,
&mut |x, y| Differentiable::of_scalar(x.clone().into_scalar() * y.clone().into_scalar()),
x.to_unranked_borrow(),
y.to_unranked_borrow(),
)
.attach_rank::<RANK1>()
.unwrap()
}
/// Produce the matrix multiplication of the inputs, threading where necessary until the
/// first argument has rank 2 and the second argument has rank 1.
pub fn star_2_1<T>(x: &Differentiable<T>, y: &Differentiable<T>) -> Differentiable<T>
where
T: Clone + Sum + Mul<Output = T> + Default,
{
ext2(
2,
1,
&mut |x, y| {
elementwise_mul_via_ext(
&x.clone().attach_rank::<2>().unwrap(),
&y.clone().attach_rank::<1>().unwrap(),
)
.to_unranked()
},
x,
y,
)
}
#[cfg(test)]
mod tests {
use crate::auto_diff::{Differentiable, RankedDifferentiable};
use crate::ext::{elementwise_mul_via_ext, ext1, ext2, star_2_1};
use crate::not_nan::{to_not_nan_1, to_not_nan_2};
use crate::scalar::Scalar;
use crate::traits::Zero;
use ordered_float::NotNan;
use std::iter::Sum;
use std::ops::Mul;
fn zeros_redefined<A>(t: &Differentiable<A>) -> Differentiable<A>
where
A: Zero,
{
ext1(
0,
&mut |_| Differentiable::of_scalar(Scalar::make(A::zero())),
t,
)
}
#[test]
fn define_zeros() {
let shape = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[1.0, 2.0],
[3.0, 4.0],
[5.0, 6.0],
]));
let zeros = zeros_redefined(&shape.to_unranked());
let to_zeros = zeros
.attach_rank::<2>()
.unwrap()
.to_vector()
.iter()
.map(|x| {
(*x).clone()
.to_vector()
.iter()
.map(|x| (*x).clone().to_scalar().clone_real_part().into_inner())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
assert_eq!(to_zeros, [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
}
fn flatten_2<A>(t: RankedDifferentiable<A, 2>) -> RankedDifferentiable<A, 1>
where
A: Clone,
{
let mut result = Vec::new();
for v in t.to_unranked_borrow().borrow_vector() {
result.extend((*v.borrow_vector()).clone())
}
Differentiable::of_vec(result).attach_rank::<1>().unwrap()
}
#[test]
fn test_flatten_2() {
let input = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[1.0, 0.5],
[3.1, 2.2],
[7.3, 2.1],
]));
let flattened = flatten_2(input);
let reshaped = flattened
.to_vector()
.iter()
.map(|x| (*x).clone().to_scalar().clone_real_part().into_inner())
.collect::<Vec<_>>();
assert_eq!(reshaped, [1.0, 0.5, 3.1, 2.2, 7.3, 2.1])
}
#[test]
fn test_flatten() {
let flatten = |t: &Differentiable<NotNan<f64>>| {
ext1(
2,
&mut |t| flatten_2((*t).clone().attach_rank::<2>().unwrap()).to_unranked(),
t,
)
};
let input = RankedDifferentiable::of_vector(vec![
RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[1.0, 0.5],
[3.1, 2.2],
[7.3, 2.1],
])),
RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[2.9, 3.5],
[0.7, 1.5],
[2.5, 6.4],
])),
]);
let flattened = flatten(&input.to_unranked())
.attach_rank::<2>()
.unwrap()
.to_vector()
.iter()
.map(|i| {
i.to_unranked_borrow()
.borrow_vector()
.iter()
.map(|j| j.borrow_scalar().clone_real_part().into_inner())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
assert_eq!(
flattened,
[
[1.0, 0.5, 3.1, 2.2, 7.3, 2.1],
[2.9, 3.5, 0.7, 1.5, 2.5, 6.4]
]
)
}
#[test]
fn test_star_2_1_a() {
let input1 = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[3.0, 4.0, 5.0],
[7.0, 8.0, 9.0],
]));
let input2 = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 4.0, 3.0]));
let output = star_2_1(input1.to_unranked_borrow(), input2.to_unranked_borrow())
.into_vector()
.iter()
.map(|x| {
x.clone()
.into_vector()
.iter()
.map(|i| i.clone().into_scalar().clone_real_part().into_inner())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
assert_eq!(output, [[6.0, 16.0, 15.0], [14.0, 32.0, 27.0]])
}
#[test]
fn test_star_2_1_b() {
let input1 = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[8.0, 1.0],
[7.0, 3.0],
[5.0, 4.0],
]));
let input2 = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
[6.0, 2.0],
[4.0, 9.0],
[3.0, 8.0],
]));
let output = star_2_1(input1.to_unranked_borrow(), input2.to_unranked_borrow())
.into_vector()
.iter()
.map(|x| {
x.clone()
.into_vector()
.iter()
.map(|i| {
i.clone()
.into_vector()
.iter()
.map(|i| i.borrow_scalar().clone_real_part().into_inner())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
assert_eq!(
output,
[
[[48.0, 2.0], [42.0, 6.0], [30.0, 8.0]],
[[32.0, 9.0], [28.0, 27.0], [20.0, 36.0]],
[[24.0, 8.0], [21.0, 24.0], [15.0, 32.0]]
]
)
}
}

View File

@@ -4,6 +4,7 @@
pub mod auto_diff;
pub mod decider;
pub mod ext;
pub mod gradient_descent;
pub mod hyper;
pub mod layer;