diff --git a/little_learner/src/decider.rs b/little_learner/src/decider.rs
index c979fb9..9b97a0b 100644
--- a/little_learner/src/decider.rs
+++ b/little_learner/src/decider.rs
@@ -3,7 +3,7 @@ use crate::loss::dot;
use crate::scalar::Scalar;
use crate::traits::{NumLike, Zero};
-fn rectify(x: A) -> A
+pub(crate) fn rectify(x: A) -> A
where
A: Zero + PartialOrd,
{
diff --git a/little_learner/src/ext.rs b/little_learner/src/ext.rs
index e6e60b8..c9a9207 100644
--- a/little_learner/src/ext.rs
+++ b/little_learner/src/ext.rs
@@ -1,6 +1,11 @@
-use crate::auto_diff::{Differentiable, DifferentiableTagged, RankedDifferentiable};
+use crate::auto_diff::{
+ Differentiable, DifferentiableTagged, RankedDifferentiable, RankedDifferentiableTagged,
+};
+use crate::decider::rectify;
+use crate::scalar::Scalar;
+use crate::traits::{NumLike, Zero};
use std::iter::Sum;
-use std::ops::Mul;
+use std::ops::{Add, Mul};
pub fn ext1(
n: usize,
@@ -49,17 +54,21 @@ where
}
}
-pub fn elementwise_mul_via_ext(
- x: &RankedDifferentiable,
- y: &RankedDifferentiable,
+pub fn elementwise_mul_via_ext(
+ x: &RankedDifferentiableTagged,
+ y: &RankedDifferentiableTagged,
) -> RankedDifferentiable
where
A: Mul