diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index e32511f..539362c 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -145,13 +145,12 @@ impl DifferentiableContents {
fn map_tag(&self, f: &mut F) -> DifferentiableContents
where
- F: FnMut(Tag) -> Tag2,
+ F: FnMut(&Tag) -> Tag2,
A: Clone,
- Tag: Clone,
{
match self {
DifferentiableContents::Scalar(a, tag) => {
- DifferentiableContents::Scalar((*a).clone(), f((*tag).clone()))
+ DifferentiableContents::Scalar((*a).clone(), f(tag))
}
DifferentiableContents::Vector(slice, rank) => {
DifferentiableContents::Vector(slice.iter().map(|x| x.map_tag(f)).collect(), *rank)
@@ -253,9 +252,8 @@ impl DifferentiableTagged {
pub fn map_tag(&self, f: &mut F) -> DifferentiableTagged
where
- F: FnMut(Tag) -> Tag2,
+ F: FnMut(&Tag) -> Tag2,
A: Clone,
- Tag: Clone,
{
DifferentiableTagged {
contents: self.contents.map_tag(f),
@@ -572,13 +570,12 @@ impl RankedDifferentiableTagged {
}
pub fn map_tag(
- self: RankedDifferentiableTagged,
+ self: &RankedDifferentiableTagged,
f: &mut F,
) -> RankedDifferentiableTagged
where
A: Clone,
- F: FnMut(Tag) -> Tag2,
- Tag: Clone,
+ F: FnMut(&Tag) -> Tag2,
{
RankedDifferentiableTagged {
contents: DifferentiableTagged::map_tag(&self.contents, f),
diff --git a/little_learner/src/decider.rs b/little_learner/src/decider.rs
new file mode 100644
index 0000000..9dcbe6c
--- /dev/null
+++ b/little_learner/src/decider.rs
@@ -0,0 +1,68 @@
+use crate::auto_diff::RankedDifferentiableTagged;
+use crate::loss::dot;
+use crate::scalar::Scalar;
+use crate::traits::{NumLike, Zero};
+
+fn rectify(x: A) -> A
+where
+ A: Zero + PartialOrd,
+{
+ if x < A::zero() {
+ A::zero()
+ } else {
+ x
+ }
+}
+
+fn linear(
+ t: RankedDifferentiableTagged,
+ theta0: RankedDifferentiableTagged,
+ theta1: Scalar,
+) -> Scalar
+where
+ A: NumLike,
+{
+ dot(&theta0, &t) + theta1
+}
+
+pub fn relu(
+ t: RankedDifferentiableTagged,
+ theta0: RankedDifferentiableTagged,
+ theta1: Scalar,
+) -> Scalar
+where
+ A: NumLike + PartialOrd,
+{
+ rectify(linear(t, theta0, theta1))
+}
+
+#[cfg(test)]
+mod test_decider {
+ use crate::auto_diff::RankedDifferentiable;
+ use crate::decider::{linear, relu};
+ use crate::not_nan::to_not_nan_1;
+ use crate::scalar::Scalar;
+ use ordered_float::NotNan;
+
+ #[test]
+ fn test_linear() {
+ let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
+ let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
+ let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
+
+ let result = linear(t, theta0, theta1).real_part().into_inner();
+
+ assert!((result + 0.1).abs() < 0.000_000_01);
+ }
+
+ #[test]
+ fn test_relu() {
+ let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
+ let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
+ let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
+
+ let result = relu(t, theta0, theta1).real_part().into_inner();
+
+ assert_eq!(result, 0.0);
+ }
+}
diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs
index e0aa15e..35daaf9 100644
--- a/little_learner/src/lib.rs
+++ b/little_learner/src/lib.rs
@@ -3,6 +3,7 @@
#![feature(array_methods)]
pub mod auto_diff;
+pub mod decider;
pub mod gradient_descent;
pub mod hyper;
pub mod loss;
diff --git a/little_learner/src/loss.rs b/little_learner/src/loss.rs
index 3e08e15..44d2eaf 100644
--- a/little_learner/src/loss.rs
+++ b/little_learner/src/loss.rs
@@ -3,7 +3,7 @@ use std::{
ops::{Add, Mul, Neg},
};
-use crate::auto_diff::Differentiable;
+use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
use crate::{
auto_diff::{DifferentiableTagged, RankedDifferentiable},
scalar::Scalar,
@@ -50,6 +50,23 @@ where
dot_unranked_tagged(x, y, |(), ()| ())
}
+pub fn dot(
+ x: &RankedDifferentiableTagged,
+ y: &RankedDifferentiableTagged,
+) -> Scalar
+where
+ A: Mul