Separate implementation (#12)

This commit is contained in:
Patrick Stevens
2023-04-08 00:50:32 +01:00
committed by GitHub
parent 753722d7ca
commit 1b738b200a
4 changed files with 226 additions and 121 deletions

4
Cargo.lock generated
View File

@@ -101,9 +101,9 @@ dependencies = [
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.53" version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73" checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435"
dependencies = [ dependencies = [
"unicode-ident", "unicode-ident",
] ]

View File

@@ -12,7 +12,9 @@ where
A: Zero, A: Zero,
{ {
fn zero() -> Differentiable<A> { fn zero() -> Differentiable<A> {
Differentiable::Scalar(Scalar::Number(A::zero(), None)) Differentiable {
contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None)),
}
} }
} }
@@ -30,7 +32,21 @@ where
A: One, A: One,
{ {
fn one() -> Differentiable<A> { fn one() -> Differentiable<A> {
Differentiable::Scalar(Scalar::one()) Differentiable {
contents: DifferentiableContents::Scalar(Scalar::one()),
}
}
}
impl<A> Clone for DifferentiableContents<A>
where
A: Clone,
{
fn clone(&self) -> Self {
match self {
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
Self::Vector(arg0, rank) => Self::Vector(arg0.clone(), *rank),
}
} }
} }
@@ -39,27 +55,32 @@ where
A: Clone, A: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
match self { Differentiable {
Self::Scalar(arg0) => Self::Scalar(arg0.clone()), contents: self.contents.clone(),
Self::Vector(arg0) => Self::Vector(arg0.clone()),
} }
} }
} }
#[derive(Debug)] #[derive(Debug)]
pub enum Differentiable<A> { enum DifferentiableContents<A> {
Scalar(Scalar<A>), Scalar(Scalar<A>),
Vector(Vec<Differentiable<A>>), // Contains the rank.
Vector(Vec<Differentiable<A>>, usize),
} }
impl<A> Display for Differentiable<A> #[derive(Debug)]
pub struct Differentiable<A> {
contents: DifferentiableContents<A>,
}
impl<A> Display for DifferentiableContents<A>
where where
A: Display, A: Display,
{ {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)), DifferentiableContents::Scalar(s) => f.write_fmt(format_args!("{}", s)),
Differentiable::Vector(v) => { DifferentiableContents::Vector(v, _rank) => {
f.write_char('[')?; f.write_char('[')?;
for v in v.iter() { for v in v.iter() {
f.write_fmt(format_args!("{}", v))?; f.write_fmt(format_args!("{}", v))?;
@@ -71,106 +92,196 @@ where
} }
} }
impl<A> Differentiable<A> { impl<A> Display for Differentiable<A>
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B> where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.contents))
}
}
impl<A> DifferentiableContents<A> {
fn map<B, F>(&self, f: &mut F) -> DifferentiableContents<B>
where where
F: FnMut(Scalar<A>) -> Scalar<B>, F: FnMut(Scalar<A>) -> Scalar<B>,
A: Clone, A: Clone,
{ {
match self { match self {
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())), DifferentiableContents::Scalar(a) => DifferentiableContents::Scalar(f(a.clone())),
Differentiable::Vector(slice) => { DifferentiableContents::Vector(slice, rank) => {
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect()) DifferentiableContents::Vector(slice.iter().map(|x| x.map(f)).collect(), *rank)
} }
} }
} }
fn map2<B, C, F>(&self, other: &DifferentiableContents<B>, f: &F) -> DifferentiableContents<C>
where
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
A: Clone,
B: Clone,
{
match (self, other) {
(DifferentiableContents::Scalar(a), DifferentiableContents::Scalar(b)) => {
DifferentiableContents::Scalar(f(a, b))
}
(
DifferentiableContents::Vector(slice_a, rank_a),
DifferentiableContents::Vector(slice_b, rank_b),
) => {
if rank_a != rank_b {
panic!("Unexpectedly different ranks in map2");
}
DifferentiableContents::Vector(
slice_a
.iter()
.zip(slice_b.iter())
.map(|(a, b)| a.map2(b, f))
.collect(),
*rank_a,
)
}
_ => panic!("Wrong shapes!"),
}
}
fn of_slice<T>(input: T) -> DifferentiableContents<A>
where
A: Clone,
T: AsRef<[A]>,
{
DifferentiableContents::Vector(
input
.as_ref()
.iter()
.map(|v| Differentiable {
contents: DifferentiableContents::Scalar(Scalar::Number((*v).clone(), None)),
})
.collect(),
1,
)
}
/// Returns the rank of this value: 0 for a scalar, or the rank stored
/// alongside the elements for a vector.
fn rank(&self) -> usize {
    if let DifferentiableContents::Vector(_, stored_rank) = self {
        *stored_rank
    } else {
        0
    }
}
}
impl<A> Differentiable<A> {
/// Produces a new `Differentiable` by passing `f` to the `map` of the
/// wrapped contents and re-wrapping the result.
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
where
    A: Clone,
    F: FnMut(Scalar<A>) -> Scalar<B>,
{
    let contents = self.contents.map(f);
    Differentiable { contents }
}
pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &F) -> Differentiable<C> pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &F) -> Differentiable<C>
where where
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>, F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
A: Clone, A: Clone,
B: Clone, B: Clone,
{ {
match (self, other) { Differentiable {
(Differentiable::Scalar(a), Differentiable::Scalar(b)) => { contents: self.contents.map2(&other.contents, f),
Differentiable::Scalar(f(a, b))
} }
(Differentiable::Vector(slice_a), Differentiable::Vector(slice_b)) => {
Differentiable::Vector(
slice_a
.iter()
.zip(slice_b.iter())
.map(|(a, b)| a.map2(b, f))
.collect(),
)
} }
_ => panic!("Wrong shapes!"),
pub fn attach_rank<const RANK: usize>(
self: Differentiable<A>,
) -> Option<RankedDifferentiable<A, RANK>> {
if self.contents.rank() == RANK {
Some(RankedDifferentiable { contents: self })
} else {
None
} }
} }
/// Wraps a single scalar as a `Differentiable` whose contents are the
/// `Scalar` variant.
pub fn of_scalar(s: Scalar<A>) -> Differentiable<A> {
    let contents = DifferentiableContents::Scalar(s);
    Differentiable { contents }
}
}
/// Accessors that unwrap a `DifferentiableContents` into the payload of one
/// specific variant, panicking when the value is the other variant.
impl<A> DifferentiableContents<A> {
    /// Consumes the value and returns the scalar it holds.
    ///
    /// # Panics
    /// Panics with "not a scalar" if this is a `Vector`.
    fn into_scalar(self) -> Scalar<A> {
        if let DifferentiableContents::Scalar(scalar) = self {
            scalar
        } else {
            panic!("not a scalar")
        }
    }

    /// Consumes the value and returns its elements, discarding the stored
    /// rank.
    ///
    /// # Panics
    /// Panics with "not a vector" if this is a `Scalar`.
    fn into_vector(self) -> Vec<Differentiable<A>> {
        if let DifferentiableContents::Vector(elements, _) = self {
            elements
        } else {
            panic!("not a vector")
        }
    }

    /// Borrows the scalar this value holds.
    ///
    /// # Panics
    /// Panics with "not a scalar" if this is a `Vector`.
    fn borrow_scalar(&self) -> &Scalar<A> {
        if let DifferentiableContents::Scalar(scalar) = self {
            scalar
        } else {
            panic!("not a scalar")
        }
    }

    /// Borrows the elements this value holds, ignoring the stored rank.
    ///
    /// # Panics
    /// Panics with "not a vector" if this is a `Scalar`.
    fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
        if let DifferentiableContents::Vector(elements, _) = self {
            elements
        } else {
            panic!("not a vector")
        }
    }
}
impl<A> Differentiable<A> {
pub fn into_scalar(self) -> Scalar<A> {
self.contents.into_scalar()
}
pub fn into_vector(self) -> Vec<Differentiable<A>> {
self.contents.into_vector()
}
pub fn borrow_scalar(&self) -> &Scalar<A> {
self.contents.borrow_scalar()
}
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
self.contents.borrow_vector()
}
fn of_slice<T>(input: T) -> Differentiable<A> fn of_slice<T>(input: T) -> Differentiable<A>
where where
A: Clone, A: Clone,
T: AsRef<[A]>, T: AsRef<[A]>,
{ {
Differentiable::Vector( Differentiable {
input contents: DifferentiableContents::of_slice(input),
.as_ref() }
.iter() }
.map(|v| Differentiable::Scalar(Scalar::Number((*v).clone(), None)))
.collect(), pub fn of_vec(input: Vec<Differentiable<A>>) -> Differentiable<A> {
) if input.is_empty() {
panic!("Can't make an empty tensor");
}
let rank = input[0].rank();
Differentiable {
contents: DifferentiableContents::Vector(input, 1 + rank),
}
} }
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
match self { self.contents.rank()
Differentiable::Scalar(_) => 0,
Differentiable::Vector(v) => v[0].rank() + 1,
} }
} }
pub fn attach_rank<const RANK: usize>( impl<A> DifferentiableContents<A>
self: Differentiable<A>,
) -> Option<RankedDifferentiable<A, RANK>> {
if self.rank() == RANK {
Some(RankedDifferentiable { contents: self })
} else {
None
}
}
}
impl<A> Differentiable<A> {
pub fn into_scalar(self) -> Scalar<A> {
match self {
Differentiable::Scalar(s) => s,
Differentiable::Vector(_) => panic!("not a scalar"),
}
}
pub fn into_vector(self) -> Vec<Differentiable<A>> {
match self {
Differentiable::Scalar(_) => panic!("not a vector"),
Differentiable::Vector(v) => v,
}
}
pub fn borrow_scalar(&self) -> &Scalar<A> {
match self {
Differentiable::Scalar(s) => s,
Differentiable::Vector(_) => panic!("not a scalar"),
}
}
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
match self {
Differentiable::Scalar(_) => panic!("not a vector"),
Differentiable::Vector(v) => v,
}
}
}
impl<A> Differentiable<A>
where where
A: Clone A: Clone
+ Eq + Eq
@@ -185,17 +296,19 @@ where
{ {
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) { fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
for v in v.iter().rev() { for v in v.iter().rev() {
v.accumulate_gradients(acc); v.contents.accumulate_gradients(acc);
} }
} }
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) { fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
match self { match self {
Differentiable::Scalar(y) => { DifferentiableContents::Scalar(y) => {
let k = y.clone_link(); let k = y.clone_link();
k.invoke(y, A::one(), acc); k.invoke(y, A::one(), acc);
} }
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc), DifferentiableContents::Vector(y, _rank) => {
DifferentiableContents::accumulate_gradients_vec(y, acc)
}
} }
} }
@@ -231,15 +344,12 @@ where
impl<A> RankedDifferentiable<A, 0> { impl<A> RankedDifferentiable<A, 0> {
pub fn to_scalar(self) -> Scalar<A> { pub fn to_scalar(self) -> Scalar<A> {
match self.contents { self.contents.contents.into_scalar()
Differentiable::Scalar(s) => s,
Differentiable::Vector(_) => panic!("not a scalar despite teq that we're a scalar"),
}
} }
pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> { pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> {
RankedDifferentiable { RankedDifferentiable {
contents: Differentiable::Scalar(s), contents: Differentiable::of_scalar(s),
} }
} }
} }
@@ -251,7 +361,9 @@ impl<A> RankedDifferentiable<A, 1> {
T: AsRef<[A]>, T: AsRef<[A]>,
{ {
RankedDifferentiable { RankedDifferentiable {
contents: Differentiable::of_slice(input), contents: Differentiable {
contents: DifferentiableContents::of_slice(input),
},
} }
} }
} }
@@ -267,7 +379,7 @@ impl<A> RankedDifferentiable<A, 2> {
.map(|x| Differentiable::of_slice(x)) .map(|x| Differentiable::of_slice(x))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
RankedDifferentiable { RankedDifferentiable {
contents: Differentiable::Vector(v), contents: Differentiable::of_vec(v),
} }
} }
} }
@@ -285,7 +397,7 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
s: Vec<RankedDifferentiable<A, RANK>>, s: Vec<RankedDifferentiable<A, RANK>>,
) -> RankedDifferentiable<A, { RANK + 1 }> { ) -> RankedDifferentiable<A, { RANK + 1 }> {
RankedDifferentiable { RankedDifferentiable {
contents: Differentiable::Vector(s.into_iter().map(|v| v.contents).collect()), contents: Differentiable::of_vec(s.into_iter().map(|v| v.contents).collect()),
} }
} }
@@ -320,13 +432,11 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
pub fn to_vector( pub fn to_vector(
self: RankedDifferentiable<A, RANK>, self: RankedDifferentiable<A, RANK>,
) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> { ) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> {
match self.contents { self.contents
Differentiable::Scalar(_) => panic!("not a scalar"), .into_vector()
Differentiable::Vector(v) => v
.into_iter() .into_iter()
.map(|v| RankedDifferentiable { contents: v }) .map(|v| RankedDifferentiable { contents: v })
.collect(), .collect()
}
} }
} }
@@ -357,7 +467,7 @@ where
}) })
}); });
let after_f = f(&wrt); let after_f = f(&wrt);
Differentiable::grad_once(after_f.contents, wrt) DifferentiableContents::grad_once(after_f.contents.contents, wrt)
} }
#[cfg(test)] #[cfg(test)]
@@ -369,21 +479,18 @@ mod tests {
use super::*; use super::*;
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A { fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
match d { d.borrow_scalar().real_part()
Differentiable::Scalar(a) => &(a.real_part()),
Differentiable::Vector(_) => panic!("not a scalar"),
}
} }
#[test] #[test]
fn test_map() { fn test_map() {
let v = Differentiable::Vector( let v = Differentiable::of_vec(
vec![ vec![
Differentiable::Scalar(Scalar::Number( Differentiable::of_scalar(Scalar::Number(
NotNan::new(3.0).expect("3 is not NaN"), NotNan::new(3.0).expect("3 is not NaN"),
Some(0usize), Some(0usize),
)), )),
Differentiable::Scalar(Scalar::Number( Differentiable::of_scalar(Scalar::Number(
NotNan::new(4.0).expect("4 is not NaN"), NotNan::new(4.0).expect("4 is not NaN"),
Some(1usize), Some(1usize),
)), )),
@@ -395,13 +502,11 @@ mod tests {
Scalar::Dual(_, _) => panic!("Not hit"), Scalar::Dual(_, _) => panic!("Not hit"),
}); });
let v = match mapped { let v = mapped
Differentiable::Scalar(_) => panic!("Not a scalar"), .into_vector()
Differentiable::Vector(v) => v
.iter() .iter()
.map(|d| extract_scalar(d).clone()) .map(|d| extract_scalar(d).clone())
.collect::<Vec<_>>(), .collect::<Vec<_>>();
};
assert_eq!(v, [4.0, 5.0]); assert_eq!(v, [4.0, 5.0]);
} }

View File

@@ -126,7 +126,7 @@ where
let dotted = RankedDifferentiable::of_scalar( let dotted = RankedDifferentiable::of_scalar(
dot_unranked( dot_unranked(
left_arg.to_unranked_borrow(), left_arg.to_unranked_borrow(),
&Differentiable::Vector(theta.to_vec()), &Differentiable::of_vec(theta.to_vec()),
) )
.into_vector() .into_vector()
.into_iter() .into_iter()
@@ -180,7 +180,7 @@ where
); );
dot_unranked( dot_unranked(
x_powers.to_unranked_borrow(), x_powers.to_unranked_borrow(),
&Differentiable::Vector(theta.to_vec()), &Differentiable::of_vec(theta.to_vec()),
) )
.attach_rank::<1>() .attach_rank::<1>()
.expect("wanted a tensor1") .expect("wanted a tensor1")

View File

@@ -106,7 +106,7 @@ fn main() {
}, },
[ [
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(), RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
Differentiable::Scalar(Scalar::zero()), Differentiable::of_scalar(Scalar::zero()),
], ],
hyper.iterations, hyper.iterations,
) )
@@ -168,7 +168,7 @@ mod tests {
#[test] #[test]
fn grad_example() { fn grad_example() {
let input_vec = [Differentiable::Scalar(Scalar::make( let input_vec = [Differentiable::of_scalar(Scalar::make(
NotNan::new(27.0).expect("not nan"), NotNan::new(27.0).expect("not nan"),
))]; ))];
@@ -362,7 +362,7 @@ mod tests {
}, },
[ [
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(), RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
Differentiable::Scalar(Scalar::zero()), Differentiable::of_scalar(Scalar::zero()),
], ],
hyper.iterations, hyper.iterations,
) )