diff --git a/Cargo.lock b/Cargo.lock
index 1e333c7..c30cc64 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -55,7 +55,6 @@ dependencies = [
name = "little_learner_app"
version = "0.1.0"
dependencies = [
- "arrayvec",
"immutable-chunkmap",
"little_learner",
"ordered-float",
diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index 0b50b4e..97d59b5 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -7,12 +7,12 @@ use std::{
ops::{AddAssign, Div, Mul, Neg},
};
-impl Zero for DifferentiableHidden
+impl Zero for Differentiable
where
A: Zero,
{
- fn zero() -> DifferentiableHidden {
- DifferentiableHidden::Scalar(Scalar::Number(A::zero(), None))
+ fn zero() -> Differentiable {
+ Differentiable::Scalar(Scalar::Number(A::zero(), None))
}
}
@@ -25,16 +25,16 @@ where
}
}
-impl One for DifferentiableHidden
+impl One for Differentiable
where
A: One,
{
- fn one() -> DifferentiableHidden {
- DifferentiableHidden::Scalar(Scalar::one())
+ fn one() -> Differentiable {
+ Differentiable::Scalar(Scalar::one())
}
}
-impl Clone for DifferentiableHidden
+impl Clone for Differentiable
where
A: Clone,
{
@@ -47,19 +47,19 @@ where
}
#[derive(Debug)]
-enum DifferentiableHidden {
+pub enum Differentiable {
Scalar(Scalar),
- Vector(Vec>),
+ Vector(Vec>),
}
-impl Display for DifferentiableHidden
+impl Display for Differentiable
where
A: Display,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
- DifferentiableHidden::Scalar(s) => f.write_fmt(format_args!("{}", s)),
- DifferentiableHidden::Vector(v) => {
+ Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
+ Differentiable::Vector(v) => {
f.write_char('[')?;
for v in v.iter() {
f.write_fmt(format_args!("{}", v))?;
@@ -71,32 +71,32 @@ where
}
}
-impl DifferentiableHidden {
- fn map(&self, f: &mut F) -> DifferentiableHidden
+impl Differentiable {
+ pub fn map(&self, f: &mut F) -> Differentiable
where
F: FnMut(Scalar) -> Scalar,
A: Clone,
{
match self {
- DifferentiableHidden::Scalar(a) => DifferentiableHidden::Scalar(f(a.clone())),
- DifferentiableHidden::Vector(slice) => {
- DifferentiableHidden::Vector(slice.iter().map(|x| x.map(f)).collect())
+ Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
+ Differentiable::Vector(slice) => {
+ Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
}
}
}
- fn map2(&self, other: &DifferentiableHidden, f: &F) -> DifferentiableHidden
+ pub fn map2(&self, other: &Differentiable, f: &F) -> Differentiable
where
F: Fn(&Scalar, &Scalar) -> Scalar,
A: Clone,
B: Clone,
{
match (self, other) {
- (DifferentiableHidden::Scalar(a), DifferentiableHidden::Scalar(b)) => {
- DifferentiableHidden::Scalar(f(a, b))
+ (Differentiable::Scalar(a), Differentiable::Scalar(b)) => {
+ Differentiable::Scalar(f(a, b))
}
- (DifferentiableHidden::Vector(slice_a), DifferentiableHidden::Vector(slice_b)) => {
- DifferentiableHidden::Vector(
+ (Differentiable::Vector(slice_a), Differentiable::Vector(slice_b)) => {
+ Differentiable::Vector(
slice_a
.iter()
.zip(slice_b.iter())
@@ -108,20 +108,69 @@ impl DifferentiableHidden {
}
}
- fn of_slice(input: &[A]) -> DifferentiableHidden
+ fn of_slice(input: T) -> Differentiable
where
A: Clone,
+ T: AsRef<[A]>,
{
- DifferentiableHidden::Vector(
+ Differentiable::Vector(
input
+ .as_ref()
.iter()
- .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone(), None)))
+ .map(|v| Differentiable::Scalar(Scalar::Number((*v).clone(), None)))
.collect(),
)
}
+
+ pub fn rank(&self) -> usize {
+ match self {
+ Differentiable::Scalar(_) => 0,
+ Differentiable::Vector(v) => v[0].rank() + 1,
+ }
+ }
+
+ pub fn attach_rank(
+ self: Differentiable,
+ ) -> Option> {
+ if self.rank() == RANK {
+ Some(RankedDifferentiable { contents: self })
+ } else {
+ None
+ }
+ }
}
-impl DifferentiableHidden
+impl Differentiable {
+ pub fn into_scalar(self) -> Scalar {
+ match self {
+ Differentiable::Scalar(s) => s,
+ Differentiable::Vector(_) => panic!("not a scalar"),
+ }
+ }
+
+ pub fn into_vector(self) -> Vec> {
+ match self {
+ Differentiable::Scalar(_) => panic!("not a vector"),
+ Differentiable::Vector(v) => v,
+ }
+ }
+
+ pub fn borrow_scalar(&self) -> &Scalar {
+ match self {
+ Differentiable::Scalar(s) => s,
+ Differentiable::Vector(_) => panic!("not a scalar"),
+ }
+ }
+
+ pub fn borrow_vector(&self) -> &Vec> {
+ match self {
+ Differentiable::Scalar(_) => panic!("not a vector"),
+ Differentiable::Vector(v) => v,
+ }
+ }
+}
+
+impl Differentiable
where
A: Clone
+ Eq
@@ -134,7 +183,7 @@ where
+ One
+ Neg