diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index f305936..de60ff8 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -12,7 +12,7 @@ where
A: Zero,
{
fn zero() -> DifferentiableHidden {
- DifferentiableHidden::Scalar(Scalar::Number(A::zero()))
+ DifferentiableHidden::Scalar(Scalar::Number(A::zero(), None))
}
}
@@ -21,7 +21,7 @@ where
A: One,
{
fn one() -> Scalar {
- Scalar::Number(A::one())
+ Scalar::Number(A::one(), None)
}
}
@@ -46,6 +46,7 @@ where
}
}
+#[derive(Debug)]
enum DifferentiableHidden {
Scalar(Scalar),
Vector(Vec>),
@@ -71,9 +72,9 @@ where
}
impl DifferentiableHidden {
- fn map(&self, f: &F) -> DifferentiableHidden
+ fn map(&self, f: &mut F) -> DifferentiableHidden
where
- F: Fn(Scalar) -> Scalar,
+ F: FnMut(Scalar) -> Scalar,
A: Clone,
{
match self {
@@ -114,7 +115,7 @@ impl DifferentiableHidden {
DifferentiableHidden::Vector(
input
.iter()
- .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone())))
+ .map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone(), None)))
.collect(),
)
}
@@ -131,7 +132,8 @@ where
+ Div