Separate implementation (#12)
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -101,9 +101,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.53"
|
||||
version = "1.0.56"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
|
||||
checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
@@ -12,7 +12,9 @@ where
|
||||
A: Zero,
|
||||
{
|
||||
fn zero() -> Differentiable<A> {
|
||||
Differentiable::Scalar(Scalar::Number(A::zero(), None))
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +32,21 @@ where
|
||||
A: One,
|
||||
{
|
||||
fn one() -> Differentiable<A> {
|
||||
Differentiable::Scalar(Scalar::one())
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Scalar(Scalar::one()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Clone for DifferentiableContents<A>
|
||||
where
|
||||
A: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
|
||||
Self::Vector(arg0, rank) => Self::Vector(arg0.clone(), *rank),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,27 +55,32 @@ where
|
||||
A: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
|
||||
Self::Vector(arg0) => Self::Vector(arg0.clone()),
|
||||
Differentiable {
|
||||
contents: self.contents.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Differentiable<A> {
|
||||
enum DifferentiableContents<A> {
|
||||
Scalar(Scalar<A>),
|
||||
Vector(Vec<Differentiable<A>>),
|
||||
// Contains the rank.
|
||||
Vector(Vec<Differentiable<A>>, usize),
|
||||
}
|
||||
|
||||
impl<A> Display for Differentiable<A>
|
||||
#[derive(Debug)]
|
||||
pub struct Differentiable<A> {
|
||||
contents: DifferentiableContents<A>,
|
||||
}
|
||||
|
||||
impl<A> Display for DifferentiableContents<A>
|
||||
where
|
||||
A: Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
||||
Differentiable::Vector(v) => {
|
||||
DifferentiableContents::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
||||
DifferentiableContents::Vector(v, _rank) => {
|
||||
f.write_char('[')?;
|
||||
for v in v.iter() {
|
||||
f.write_fmt(format_args!("{}", v))?;
|
||||
@@ -71,106 +92,196 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
|
||||
impl<A> Display for Differentiable<A>
|
||||
where
|
||||
A: Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("{}", self.contents))
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> DifferentiableContents<A> {
|
||||
fn map<B, F>(&self, f: &mut F) -> DifferentiableContents<B>
|
||||
where
|
||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||
A: Clone,
|
||||
{
|
||||
match self {
|
||||
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
|
||||
Differentiable::Vector(slice) => {
|
||||
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
|
||||
DifferentiableContents::Scalar(a) => DifferentiableContents::Scalar(f(a.clone())),
|
||||
DifferentiableContents::Vector(slice, rank) => {
|
||||
DifferentiableContents::Vector(slice.iter().map(|x| x.map(f)).collect(), *rank)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map2<B, C, F>(&self, other: &DifferentiableContents<B>, f: &F) -> DifferentiableContents<C>
|
||||
where
|
||||
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
{
|
||||
match (self, other) {
|
||||
(DifferentiableContents::Scalar(a), DifferentiableContents::Scalar(b)) => {
|
||||
DifferentiableContents::Scalar(f(a, b))
|
||||
}
|
||||
(
|
||||
DifferentiableContents::Vector(slice_a, rank_a),
|
||||
DifferentiableContents::Vector(slice_b, rank_b),
|
||||
) => {
|
||||
if rank_a != rank_b {
|
||||
panic!("Unexpectedly different ranks in map2");
|
||||
}
|
||||
DifferentiableContents::Vector(
|
||||
slice_a
|
||||
.iter()
|
||||
.zip(slice_b.iter())
|
||||
.map(|(a, b)| a.map2(b, f))
|
||||
.collect(),
|
||||
*rank_a,
|
||||
)
|
||||
}
|
||||
_ => panic!("Wrong shapes!"),
|
||||
}
|
||||
}
|
||||
|
||||
fn of_slice<T>(input: T) -> DifferentiableContents<A>
|
||||
where
|
||||
A: Clone,
|
||||
T: AsRef<[A]>,
|
||||
{
|
||||
DifferentiableContents::Vector(
|
||||
input
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|v| Differentiable {
|
||||
contents: DifferentiableContents::Scalar(Scalar::Number((*v).clone(), None)),
|
||||
})
|
||||
.collect(),
|
||||
1,
|
||||
)
|
||||
}
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(_) => 0,
|
||||
DifferentiableContents::Vector(_, rank) => *rank,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
|
||||
where
|
||||
A: Clone,
|
||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||
{
|
||||
Differentiable {
|
||||
contents: self.contents.map(f),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &F) -> Differentiable<C>
|
||||
where
|
||||
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
{
|
||||
match (self, other) {
|
||||
(Differentiable::Scalar(a), Differentiable::Scalar(b)) => {
|
||||
Differentiable::Scalar(f(a, b))
|
||||
}
|
||||
(Differentiable::Vector(slice_a), Differentiable::Vector(slice_b)) => {
|
||||
Differentiable::Vector(
|
||||
slice_a
|
||||
.iter()
|
||||
.zip(slice_b.iter())
|
||||
.map(|(a, b)| a.map2(b, f))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
_ => panic!("Wrong shapes!"),
|
||||
Differentiable {
|
||||
contents: self.contents.map2(&other.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_rank<const RANK: usize>(
|
||||
self: Differentiable<A>,
|
||||
) -> Option<RankedDifferentiable<A, RANK>> {
|
||||
if self.contents.rank() == RANK {
|
||||
Some(RankedDifferentiable { contents: self })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn of_scalar(s: Scalar<A>) -> Differentiable<A> {
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Scalar(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> DifferentiableContents<A> {
|
||||
fn into_scalar(self) -> Scalar<A> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(s) => s,
|
||||
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(_) => panic!("not a vector"),
|
||||
DifferentiableContents::Vector(v, _) => v,
|
||||
}
|
||||
}
|
||||
|
||||
fn borrow_scalar(&self) -> &Scalar<A> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(s) => s,
|
||||
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(_) => panic!("not a vector"),
|
||||
DifferentiableContents::Vector(v, _) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
pub fn into_scalar(self) -> Scalar<A> {
|
||||
self.contents.into_scalar()
|
||||
}
|
||||
|
||||
pub fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||
self.contents.into_vector()
|
||||
}
|
||||
|
||||
pub fn borrow_scalar(&self) -> &Scalar<A> {
|
||||
self.contents.borrow_scalar()
|
||||
}
|
||||
|
||||
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||
self.contents.borrow_vector()
|
||||
}
|
||||
|
||||
fn of_slice<T>(input: T) -> Differentiable<A>
|
||||
where
|
||||
A: Clone,
|
||||
T: AsRef<[A]>,
|
||||
{
|
||||
Differentiable::Vector(
|
||||
input
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|v| Differentiable::Scalar(Scalar::Number((*v).clone(), None)))
|
||||
.collect(),
|
||||
)
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::of_slice(input),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn of_vec(input: Vec<Differentiable<A>>) -> Differentiable<A> {
|
||||
if input.is_empty() {
|
||||
panic!("Can't make an empty tensor");
|
||||
}
|
||||
let rank = input[0].rank();
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Vector(input, 1 + rank),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
match self {
|
||||
Differentiable::Scalar(_) => 0,
|
||||
Differentiable::Vector(v) => v[0].rank() + 1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_rank<const RANK: usize>(
|
||||
self: Differentiable<A>,
|
||||
) -> Option<RankedDifferentiable<A, RANK>> {
|
||||
if self.rank() == RANK {
|
||||
Some(RankedDifferentiable { contents: self })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
self.contents.rank()
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
pub fn into_scalar(self) -> Scalar<A> {
|
||||
match self {
|
||||
Differentiable::Scalar(s) => s,
|
||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||
match self {
|
||||
Differentiable::Scalar(_) => panic!("not a vector"),
|
||||
Differentiable::Vector(v) => v,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn borrow_scalar(&self) -> &Scalar<A> {
|
||||
match self {
|
||||
Differentiable::Scalar(s) => s,
|
||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||
match self {
|
||||
Differentiable::Scalar(_) => panic!("not a vector"),
|
||||
Differentiable::Vector(v) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A>
|
||||
impl<A> DifferentiableContents<A>
|
||||
where
|
||||
A: Clone
|
||||
+ Eq
|
||||
@@ -185,17 +296,19 @@ where
|
||||
{
|
||||
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
|
||||
for v in v.iter().rev() {
|
||||
v.accumulate_gradients(acc);
|
||||
v.contents.accumulate_gradients(acc);
|
||||
}
|
||||
}
|
||||
|
||||
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
|
||||
match self {
|
||||
Differentiable::Scalar(y) => {
|
||||
DifferentiableContents::Scalar(y) => {
|
||||
let k = y.clone_link();
|
||||
k.invoke(y, A::one(), acc);
|
||||
}
|
||||
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc),
|
||||
DifferentiableContents::Vector(y, _rank) => {
|
||||
DifferentiableContents::accumulate_gradients_vec(y, acc)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,15 +344,12 @@ where
|
||||
|
||||
impl<A> RankedDifferentiable<A, 0> {
|
||||
pub fn to_scalar(self) -> Scalar<A> {
|
||||
match self.contents {
|
||||
Differentiable::Scalar(s) => s,
|
||||
Differentiable::Vector(_) => panic!("not a scalar despite teq that we're a scalar"),
|
||||
}
|
||||
self.contents.contents.into_scalar()
|
||||
}
|
||||
|
||||
pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> {
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::Scalar(s),
|
||||
contents: Differentiable::of_scalar(s),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,7 +361,9 @@ impl<A> RankedDifferentiable<A, 1> {
|
||||
T: AsRef<[A]>,
|
||||
{
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::of_slice(input),
|
||||
contents: Differentiable {
|
||||
contents: DifferentiableContents::of_slice(input),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -267,7 +379,7 @@ impl<A> RankedDifferentiable<A, 2> {
|
||||
.map(|x| Differentiable::of_slice(x))
|
||||
.collect::<Vec<_>>();
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::Vector(v),
|
||||
contents: Differentiable::of_vec(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -285,7 +397,7 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
||||
s: Vec<RankedDifferentiable<A, RANK>>,
|
||||
) -> RankedDifferentiable<A, { RANK + 1 }> {
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::Vector(s.into_iter().map(|v| v.contents).collect()),
|
||||
contents: Differentiable::of_vec(s.into_iter().map(|v| v.contents).collect()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -320,13 +432,11 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
||||
pub fn to_vector(
|
||||
self: RankedDifferentiable<A, RANK>,
|
||||
) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> {
|
||||
match self.contents {
|
||||
Differentiable::Scalar(_) => panic!("not a scalar"),
|
||||
Differentiable::Vector(v) => v
|
||||
.into_iter()
|
||||
.map(|v| RankedDifferentiable { contents: v })
|
||||
.collect(),
|
||||
}
|
||||
self.contents
|
||||
.into_vector()
|
||||
.into_iter()
|
||||
.map(|v| RankedDifferentiable { contents: v })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,7 +467,7 @@ where
|
||||
})
|
||||
});
|
||||
let after_f = f(&wrt);
|
||||
Differentiable::grad_once(after_f.contents, wrt)
|
||||
DifferentiableContents::grad_once(after_f.contents.contents, wrt)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -369,21 +479,18 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
|
||||
match d {
|
||||
Differentiable::Scalar(a) => &(a.real_part()),
|
||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
||||
}
|
||||
d.borrow_scalar().real_part()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_map() {
|
||||
let v = Differentiable::Vector(
|
||||
let v = Differentiable::of_vec(
|
||||
vec![
|
||||
Differentiable::Scalar(Scalar::Number(
|
||||
Differentiable::of_scalar(Scalar::Number(
|
||||
NotNan::new(3.0).expect("3 is not NaN"),
|
||||
Some(0usize),
|
||||
)),
|
||||
Differentiable::Scalar(Scalar::Number(
|
||||
Differentiable::of_scalar(Scalar::Number(
|
||||
NotNan::new(4.0).expect("4 is not NaN"),
|
||||
Some(1usize),
|
||||
)),
|
||||
@@ -395,13 +502,11 @@ mod tests {
|
||||
Scalar::Dual(_, _) => panic!("Not hit"),
|
||||
});
|
||||
|
||||
let v = match mapped {
|
||||
Differentiable::Scalar(_) => panic!("Not a scalar"),
|
||||
Differentiable::Vector(v) => v
|
||||
.iter()
|
||||
.map(|d| extract_scalar(d).clone())
|
||||
.collect::<Vec<_>>(),
|
||||
};
|
||||
let v = mapped
|
||||
.into_vector()
|
||||
.iter()
|
||||
.map(|d| extract_scalar(d).clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
assert_eq!(v, [4.0, 5.0]);
|
||||
}
|
||||
|
@@ -126,7 +126,7 @@ where
|
||||
let dotted = RankedDifferentiable::of_scalar(
|
||||
dot_unranked(
|
||||
left_arg.to_unranked_borrow(),
|
||||
&Differentiable::Vector(theta.to_vec()),
|
||||
&Differentiable::of_vec(theta.to_vec()),
|
||||
)
|
||||
.into_vector()
|
||||
.into_iter()
|
||||
@@ -180,7 +180,7 @@ where
|
||||
);
|
||||
dot_unranked(
|
||||
x_powers.to_unranked_borrow(),
|
||||
&Differentiable::Vector(theta.to_vec()),
|
||||
&Differentiable::of_vec(theta.to_vec()),
|
||||
)
|
||||
.attach_rank::<1>()
|
||||
.expect("wanted a tensor1")
|
||||
|
@@ -106,7 +106,7 @@ fn main() {
|
||||
},
|
||||
[
|
||||
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
|
||||
Differentiable::Scalar(Scalar::zero()),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
],
|
||||
hyper.iterations,
|
||||
)
|
||||
@@ -168,7 +168,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn grad_example() {
|
||||
let input_vec = [Differentiable::Scalar(Scalar::make(
|
||||
let input_vec = [Differentiable::of_scalar(Scalar::make(
|
||||
NotNan::new(27.0).expect("not nan"),
|
||||
))];
|
||||
|
||||
@@ -362,7 +362,7 @@ mod tests {
|
||||
},
|
||||
[
|
||||
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
|
||||
Differentiable::Scalar(Scalar::zero()),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
],
|
||||
hyper.iterations,
|
||||
)
|
||||
|
Reference in New Issue
Block a user