Separate implementation (#12)

This commit is contained in:
Patrick Stevens
2023-04-08 00:50:32 +01:00
committed by GitHub
parent 753722d7ca
commit 1b738b200a
4 changed files with 226 additions and 121 deletions

4
Cargo.lock generated
View File

@@ -101,9 +101,9 @@ dependencies = [
[[package]]
name = "proc-macro2"
version = "1.0.53"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435"
dependencies = [
"unicode-ident",
]

View File

@@ -12,7 +12,9 @@ where
A: Zero,
{
fn zero() -> Differentiable<A> {
Differentiable::Scalar(Scalar::Number(A::zero(), None))
Differentiable {
contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None)),
}
}
}
@@ -30,7 +32,21 @@ where
A: One,
{
fn one() -> Differentiable<A> {
Differentiable::Scalar(Scalar::one())
Differentiable {
contents: DifferentiableContents::Scalar(Scalar::one()),
}
}
}
impl<A> Clone for DifferentiableContents<A>
where
A: Clone,
{
fn clone(&self) -> Self {
match self {
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
Self::Vector(arg0, rank) => Self::Vector(arg0.clone(), *rank),
}
}
}
@@ -39,27 +55,32 @@ where
A: Clone,
{
fn clone(&self) -> Self {
match self {
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
Self::Vector(arg0) => Self::Vector(arg0.clone()),
Differentiable {
contents: self.contents.clone(),
}
}
}
#[derive(Debug)]
pub enum Differentiable<A> {
enum DifferentiableContents<A> {
Scalar(Scalar<A>),
Vector(Vec<Differentiable<A>>),
// Contains the rank.
Vector(Vec<Differentiable<A>>, usize),
}
impl<A> Display for Differentiable<A>
#[derive(Debug)]
pub struct Differentiable<A> {
contents: DifferentiableContents<A>,
}
impl<A> Display for DifferentiableContents<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
Differentiable::Vector(v) => {
DifferentiableContents::Scalar(s) => f.write_fmt(format_args!("{}", s)),
DifferentiableContents::Vector(v, _rank) => {
f.write_char('[')?;
for v in v.iter() {
f.write_fmt(format_args!("{}", v))?;
@@ -71,106 +92,196 @@ where
}
}
impl<A> Differentiable<A> {
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
impl<A> Display for Differentiable<A>
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{}", self.contents))
}
}
impl<A> DifferentiableContents<A> {
fn map<B, F>(&self, f: &mut F) -> DifferentiableContents<B>
where
F: FnMut(Scalar<A>) -> Scalar<B>,
A: Clone,
{
match self {
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
Differentiable::Vector(slice) => {
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
DifferentiableContents::Scalar(a) => DifferentiableContents::Scalar(f(a.clone())),
DifferentiableContents::Vector(slice, rank) => {
DifferentiableContents::Vector(slice.iter().map(|x| x.map(f)).collect(), *rank)
}
}
}
fn map2<B, C, F>(&self, other: &DifferentiableContents<B>, f: &F) -> DifferentiableContents<C>
where
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
A: Clone,
B: Clone,
{
match (self, other) {
(DifferentiableContents::Scalar(a), DifferentiableContents::Scalar(b)) => {
DifferentiableContents::Scalar(f(a, b))
}
(
DifferentiableContents::Vector(slice_a, rank_a),
DifferentiableContents::Vector(slice_b, rank_b),
) => {
if rank_a != rank_b {
panic!("Unexpectedly different ranks in map2");
}
DifferentiableContents::Vector(
slice_a
.iter()
.zip(slice_b.iter())
.map(|(a, b)| a.map2(b, f))
.collect(),
*rank_a,
)
}
_ => panic!("Wrong shapes!"),
}
}
fn of_slice<T>(input: T) -> DifferentiableContents<A>
where
A: Clone,
T: AsRef<[A]>,
{
DifferentiableContents::Vector(
input
.as_ref()
.iter()
.map(|v| Differentiable {
contents: DifferentiableContents::Scalar(Scalar::Number((*v).clone(), None)),
})
.collect(),
1,
)
}
fn rank(&self) -> usize {
match self {
DifferentiableContents::Scalar(_) => 0,
DifferentiableContents::Vector(_, rank) => *rank,
}
}
}
impl<A> Differentiable<A> {
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
where
A: Clone,
F: FnMut(Scalar<A>) -> Scalar<B>,
{
Differentiable {
contents: self.contents.map(f),
}
}
pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &F) -> Differentiable<C>
where
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
A: Clone,
B: Clone,
{
match (self, other) {
(Differentiable::Scalar(a), Differentiable::Scalar(b)) => {
Differentiable::Scalar(f(a, b))
Differentiable {
contents: self.contents.map2(&other.contents, f),
}
(Differentiable::Vector(slice_a), Differentiable::Vector(slice_b)) => {
Differentiable::Vector(
slice_a
.iter()
.zip(slice_b.iter())
.map(|(a, b)| a.map2(b, f))
.collect(),
)
}
_ => panic!("Wrong shapes!"),
pub fn attach_rank<const RANK: usize>(
self: Differentiable<A>,
) -> Option<RankedDifferentiable<A, RANK>> {
if self.contents.rank() == RANK {
Some(RankedDifferentiable { contents: self })
} else {
None
}
}
pub fn of_scalar(s: Scalar<A>) -> Differentiable<A> {
Differentiable {
contents: DifferentiableContents::Scalar(s),
}
}
}
impl<A> DifferentiableContents<A> {
fn into_scalar(self) -> Scalar<A> {
match self {
DifferentiableContents::Scalar(s) => s,
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
}
}
fn into_vector(self) -> Vec<Differentiable<A>> {
match self {
DifferentiableContents::Scalar(_) => panic!("not a vector"),
DifferentiableContents::Vector(v, _) => v,
}
}
fn borrow_scalar(&self) -> &Scalar<A> {
match self {
DifferentiableContents::Scalar(s) => s,
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
}
}
fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
match self {
DifferentiableContents::Scalar(_) => panic!("not a vector"),
DifferentiableContents::Vector(v, _) => v,
}
}
}
impl<A> Differentiable<A> {
pub fn into_scalar(self) -> Scalar<A> {
self.contents.into_scalar()
}
pub fn into_vector(self) -> Vec<Differentiable<A>> {
self.contents.into_vector()
}
pub fn borrow_scalar(&self) -> &Scalar<A> {
self.contents.borrow_scalar()
}
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
self.contents.borrow_vector()
}
fn of_slice<T>(input: T) -> Differentiable<A>
where
A: Clone,
T: AsRef<[A]>,
{
Differentiable::Vector(
input
.as_ref()
.iter()
.map(|v| Differentiable::Scalar(Scalar::Number((*v).clone(), None)))
.collect(),
)
Differentiable {
contents: DifferentiableContents::of_slice(input),
}
}
pub fn of_vec(input: Vec<Differentiable<A>>) -> Differentiable<A> {
if input.is_empty() {
panic!("Can't make an empty tensor");
}
let rank = input[0].rank();
Differentiable {
contents: DifferentiableContents::Vector(input, 1 + rank),
}
}
pub fn rank(&self) -> usize {
match self {
Differentiable::Scalar(_) => 0,
Differentiable::Vector(v) => v[0].rank() + 1,
self.contents.rank()
}
}
pub fn attach_rank<const RANK: usize>(
self: Differentiable<A>,
) -> Option<RankedDifferentiable<A, RANK>> {
if self.rank() == RANK {
Some(RankedDifferentiable { contents: self })
} else {
None
}
}
}
impl<A> Differentiable<A> {
pub fn into_scalar(self) -> Scalar<A> {
match self {
Differentiable::Scalar(s) => s,
Differentiable::Vector(_) => panic!("not a scalar"),
}
}
pub fn into_vector(self) -> Vec<Differentiable<A>> {
match self {
Differentiable::Scalar(_) => panic!("not a vector"),
Differentiable::Vector(v) => v,
}
}
pub fn borrow_scalar(&self) -> &Scalar<A> {
match self {
Differentiable::Scalar(s) => s,
Differentiable::Vector(_) => panic!("not a scalar"),
}
}
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
match self {
Differentiable::Scalar(_) => panic!("not a vector"),
Differentiable::Vector(v) => v,
}
}
}
impl<A> Differentiable<A>
impl<A> DifferentiableContents<A>
where
A: Clone
+ Eq
@@ -185,17 +296,19 @@ where
{
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
for v in v.iter().rev() {
v.accumulate_gradients(acc);
v.contents.accumulate_gradients(acc);
}
}
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
match self {
Differentiable::Scalar(y) => {
DifferentiableContents::Scalar(y) => {
let k = y.clone_link();
k.invoke(y, A::one(), acc);
}
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc),
DifferentiableContents::Vector(y, _rank) => {
DifferentiableContents::accumulate_gradients_vec(y, acc)
}
}
}
@@ -231,15 +344,12 @@ where
impl<A> RankedDifferentiable<A, 0> {
pub fn to_scalar(self) -> Scalar<A> {
match self.contents {
Differentiable::Scalar(s) => s,
Differentiable::Vector(_) => panic!("not a scalar despite teq that we're a scalar"),
}
self.contents.contents.into_scalar()
}
pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> {
RankedDifferentiable {
contents: Differentiable::Scalar(s),
contents: Differentiable::of_scalar(s),
}
}
}
@@ -251,7 +361,9 @@ impl<A> RankedDifferentiable<A, 1> {
T: AsRef<[A]>,
{
RankedDifferentiable {
contents: Differentiable::of_slice(input),
contents: Differentiable {
contents: DifferentiableContents::of_slice(input),
},
}
}
}
@@ -267,7 +379,7 @@ impl<A> RankedDifferentiable<A, 2> {
.map(|x| Differentiable::of_slice(x))
.collect::<Vec<_>>();
RankedDifferentiable {
contents: Differentiable::Vector(v),
contents: Differentiable::of_vec(v),
}
}
}
@@ -285,7 +397,7 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
s: Vec<RankedDifferentiable<A, RANK>>,
) -> RankedDifferentiable<A, { RANK + 1 }> {
RankedDifferentiable {
contents: Differentiable::Vector(s.into_iter().map(|v| v.contents).collect()),
contents: Differentiable::of_vec(s.into_iter().map(|v| v.contents).collect()),
}
}
@@ -320,13 +432,11 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
pub fn to_vector(
self: RankedDifferentiable<A, RANK>,
) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> {
match self.contents {
Differentiable::Scalar(_) => panic!("not a scalar"),
Differentiable::Vector(v) => v
self.contents
.into_vector()
.into_iter()
.map(|v| RankedDifferentiable { contents: v })
.collect(),
}
.collect()
}
}
@@ -357,7 +467,7 @@ where
})
});
let after_f = f(&wrt);
Differentiable::grad_once(after_f.contents, wrt)
DifferentiableContents::grad_once(after_f.contents.contents, wrt)
}
#[cfg(test)]
@@ -369,21 +479,18 @@ mod tests {
use super::*;
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
match d {
Differentiable::Scalar(a) => &(a.real_part()),
Differentiable::Vector(_) => panic!("not a scalar"),
}
d.borrow_scalar().real_part()
}
#[test]
fn test_map() {
let v = Differentiable::Vector(
let v = Differentiable::of_vec(
vec![
Differentiable::Scalar(Scalar::Number(
Differentiable::of_scalar(Scalar::Number(
NotNan::new(3.0).expect("3 is not NaN"),
Some(0usize),
)),
Differentiable::Scalar(Scalar::Number(
Differentiable::of_scalar(Scalar::Number(
NotNan::new(4.0).expect("4 is not NaN"),
Some(1usize),
)),
@@ -395,13 +502,11 @@ mod tests {
Scalar::Dual(_, _) => panic!("Not hit"),
});
let v = match mapped {
Differentiable::Scalar(_) => panic!("Not a scalar"),
Differentiable::Vector(v) => v
let v = mapped
.into_vector()
.iter()
.map(|d| extract_scalar(d).clone())
.collect::<Vec<_>>(),
};
.collect::<Vec<_>>();
assert_eq!(v, [4.0, 5.0]);
}

View File

@@ -126,7 +126,7 @@ where
let dotted = RankedDifferentiable::of_scalar(
dot_unranked(
left_arg.to_unranked_borrow(),
&Differentiable::Vector(theta.to_vec()),
&Differentiable::of_vec(theta.to_vec()),
)
.into_vector()
.into_iter()
@@ -180,7 +180,7 @@ where
);
dot_unranked(
x_powers.to_unranked_borrow(),
&Differentiable::Vector(theta.to_vec()),
&Differentiable::of_vec(theta.to_vec()),
)
.attach_rank::<1>()
.expect("wanted a tensor1")

View File

@@ -106,7 +106,7 @@ fn main() {
},
[
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
Differentiable::Scalar(Scalar::zero()),
Differentiable::of_scalar(Scalar::zero()),
],
hyper.iterations,
)
@@ -168,7 +168,7 @@ mod tests {
#[test]
fn grad_example() {
let input_vec = [Differentiable::Scalar(Scalar::make(
let input_vec = [Differentiable::of_scalar(Scalar::make(
NotNan::new(27.0).expect("not nan"),
))];
@@ -362,7 +362,7 @@ mod tests {
},
[
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
Differentiable::Scalar(Scalar::zero()),
Differentiable::of_scalar(Scalar::zero()),
],
hyper.iterations,
)