diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index 539362c..cae7c0b 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -69,7 +69,7 @@ where
#[derive(Debug)]
enum DifferentiableContents {
Scalar(Scalar, Tag),
- // Contains the rank.
+ // Contains the rank of this differentiable (i.e. one more than the rank of each contained element).
Vector(Vec>, usize),
}
@@ -199,6 +199,64 @@ impl DifferentiableContents {
}
}
+ /// Unwraps one layer of each input, so the passed function receives arguments whose ranks
+ /// are one less than the ranks of the `map2_once_tagged` inputs.
+ /// Panics if passed a scalar or if the input vectors are not the same length.
+ pub fn map2_once_tagged(
+ self: &DifferentiableContents,
+ other: &DifferentiableContents,
+ mut f: F,
+ ) -> DifferentiableContents
+ where
+ F: FnMut(
+ &DifferentiableTagged,
+ &DifferentiableTagged,
+ ) -> DifferentiableTagged,
+ {
+ match (self, other) {
+ (DifferentiableContents::Scalar(_, _), _) => {
+ panic!("First arg needed to have non-scalar rank")
+ }
+ (_, DifferentiableContents::Scalar(_, _)) => {
+ panic!("Second arg needed to have non-scalar rank")
+ }
+ (
+ DifferentiableContents::Vector(v1, rank1),
+ DifferentiableContents::Vector(v2, _rank2),
+ ) => {
+ assert_eq!(
+ v1.len(),
+ v2.len(),
+ "Must map two vectors of the same length; ranks were {rank1} and {_rank2}"
+ );
+ assert_ne!(
+ v1.len(),
+ 0,
+ "Cannot determine a rank of a zero-length vector"
+ );
+ let mut rank = 0usize;
+ DifferentiableContents::Vector(
+ v1.iter()
+ .zip(v2.iter())
+ .map(|(a, b)| {
+ let result = f(a, b);
+ match result.contents {
+ DifferentiableContents::Vector(_, discovered_rank) => {
+ rank = discovered_rank + 1;
+ }
+ DifferentiableContents::Scalar(_, _) => {
+ rank = 1;
+ }
+ }
+ result
+ })
+ .collect(),
+ rank,
+ )
+ }
+ }
+ }
+
fn of_slice<'a, T, I>(tag: Tag, input: I) -> DifferentiableContents
where
T: Clone + 'a,
@@ -277,6 +335,22 @@ impl DifferentiableTagged {
}
}
+ pub fn map2_once_tagged(
+ self: &DifferentiableTagged,
+ other: &DifferentiableTagged,
+ f: F,
+ ) -> DifferentiableTagged
+ where
+ F: FnMut(
+ &DifferentiableTagged,
+ &DifferentiableTagged,
+ ) -> DifferentiableTagged,
+ {
+ DifferentiableTagged {
+ contents: self.contents.map2_once_tagged(&other.contents, f),
+ }
+ }
+
pub fn attach_rank(
self: DifferentiableTagged,
) -> Option> {
@@ -582,10 +656,10 @@ impl RankedDifferentiableTagged {
}
}
- pub fn map2_tagged(
- self: &RankedDifferentiableTagged,
- other: &RankedDifferentiableTagged,
- f: &mut F,
+ pub fn map2_tagged<'a, 'b, B, C, Tag2, Tag3, F>(
+ self: &'a RankedDifferentiableTagged,
+ other: &'a RankedDifferentiableTagged,
+ f: &'b mut F,
) -> RankedDifferentiableTagged
where
F: FnMut(&Scalar, Tag, &Scalar, Tag2) -> (Scalar, Tag3),
@@ -598,6 +672,44 @@ impl RankedDifferentiableTagged {
contents: DifferentiableTagged::map2_tagged(&self.contents, &other.contents, f),
}
}
+ pub fn map2_once_tagged<
+ 'a,
+ 'c,
+ B,
+ C: 'a,
+ Tag2,
+ Tag3: 'a,
+ F,
+ const RANK_B: usize,
+ const RANK_OUT: usize,
+ >(
+ self: &'a RankedDifferentiableTagged,
+ other: &'a RankedDifferentiableTagged,
+ f: &'c mut F,
+ ) -> RankedDifferentiableTagged
+ where
+ F: FnMut(
+ &RankedDifferentiableTagged,
+ &RankedDifferentiableTagged,
+ ) -> RankedDifferentiableTagged,
+ A: Clone,
+ B: Clone,
+ Tag: Clone,
+ Tag2: Clone,
+ 'c: 'a,
+ {
+ RankedDifferentiableTagged {
+ contents: DifferentiableTagged::map2_once_tagged(
+ &self.contents,
+ &other.contents,
+ &mut |a: &DifferentiableTagged, b: &DifferentiableTagged| {
+ let a = (*a).clone().attach_rank::<{ RANK - 1 }>().unwrap();
+ let b = (*b).clone().attach_rank::<{ RANK_B - 1 }>().unwrap();
+ f(&a, &b).to_unranked()
+ },
+ ),
+ }
+ }
pub fn to_vector(
self: RankedDifferentiableTagged,
@@ -634,6 +746,22 @@ impl RankedDifferentiable {
{
self.map2_tagged(other, &mut |a, (), b, ()| (f(a, b), ()))
}
+
+ pub fn map2_once(
+ self: &RankedDifferentiable,
+ other: &RankedDifferentiable,
+ f: &mut F,
+ ) -> RankedDifferentiable
+ where
+ F: FnMut(
+ &RankedDifferentiable,
+ &RankedDifferentiable,
+ ) -> RankedDifferentiable,
+ A: Clone,
+ B: Clone,
+ {
+ self.map2_once_tagged(other, f)
+ }
}
pub fn grad(
diff --git a/little_learner/src/decider.rs b/little_learner/src/decider.rs
index 9dcbe6c..c979fb9 100644
--- a/little_learner/src/decider.rs
+++ b/little_learner/src/decider.rs
@@ -15,19 +15,19 @@ where
}
fn linear(
- t: RankedDifferentiableTagged,
- theta0: RankedDifferentiableTagged,
+ t: &RankedDifferentiableTagged,
+ theta0: &RankedDifferentiableTagged,
theta1: Scalar,
) -> Scalar
where
A: NumLike,
{
- dot(&theta0, &t) + theta1
+ dot(theta0, t) + theta1
}
pub fn relu(
- t: RankedDifferentiableTagged,
- theta0: RankedDifferentiableTagged,
+ t: &RankedDifferentiableTagged,
+ theta0: &RankedDifferentiableTagged,
theta1: Scalar,
) -> Scalar
where
@@ -50,7 +50,7 @@ mod test_decider {
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
- let result = linear(t, theta0, theta1).real_part().into_inner();
+ let result = linear(&t, &theta0, theta1).real_part().into_inner();
assert!((result + 0.1).abs() < 0.000_000_01);
}
@@ -61,7 +61,7 @@ mod test_decider {
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
- let result = relu(t, theta0, theta1).real_part().into_inner();
+ let result = relu(&t, &theta0, theta1).real_part().into_inner();
assert_eq!(result, 0.0);
}
diff --git a/little_learner/src/layer.rs b/little_learner/src/layer.rs
new file mode 100644
index 0000000..18117e2
--- /dev/null
+++ b/little_learner/src/layer.rs
@@ -0,0 +1,75 @@
+use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
+use crate::decider::relu;
+use crate::traits::NumLike;
+
+/// Returns a tensor1.
+/// Theta has two components: a tensor2 of weights and a tensor1 of bias.
+pub fn layer(
+ theta: Differentiable,
+ t: RankedDifferentiable,
+) -> RankedDifferentiable
+where
+ T: NumLike + PartialOrd,
+{
+ let mut theta = theta.into_vector();
+ assert_eq!(theta.len(), 2, "Needed weights and a bias");
+ let b = theta.pop().unwrap().attach_rank::<1>().unwrap();
+ let w = theta.pop().unwrap().attach_rank::<2>().unwrap();
+
+ RankedDifferentiableTagged::map2_once(
+ &w,
+ &b,
+ &mut |w: &RankedDifferentiable<_, 1>, b: &RankedDifferentiable<_, 0>| {
+ RankedDifferentiableTagged::of_scalar(relu(&t, w, b.clone().to_scalar()))
+ },
+ )
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::auto_diff::{Differentiable, RankedDifferentiable};
+ use crate::layer::layer;
+ use crate::not_nan::{to_not_nan_1, to_not_nan_2};
+
+ #[test]
+ fn test_single_layer() {
+ let b = RankedDifferentiable::of_slice(&to_not_nan_1([1.0, 2.0]));
+ let w = RankedDifferentiable::of_slice_2::<_, 2>(&to_not_nan_2([
+ [3.0, 4.0, 5.0],
+ [6.0, 7.0, 8.0],
+ ]));
+ let theta = Differentiable::of_vec(vec![w.to_unranked(), b.to_unranked()]);
+
+ /*
+ Two neurons:
+ w =
+ (3 4 5
+ 6 7 8)
+ b = (1, 2)
+
+ Three inputs:
+ t = (9, 10, 11)
+
+ Output has two elements, one per neuron.
+ Neuron 1 has weights (3,4,5) and bias 1;
+ Neuron 2 has weights (6,7,8) and bias 2.
+
+ Neuron 1 is relu(t, (3,4,5), 1), which is (9, 10, 11).(3, 4, 5) + 1.
+ Neuron 2 is relu(t, (6,7,8), 2), which is (9, 10, 11).(6, 7, 8) + 2.
+ */
+
+ let t = RankedDifferentiable::of_slice(&to_not_nan_1([9.0, 10.0, 11.0]));
+ let mut output = layer(theta, t)
+ .to_vector()
+ .iter()
+ .map(|t| (*t).clone().to_scalar().clone_real_part().into_inner())
+ .collect::>();
+
+ assert_eq!(output.len(), 2);
+ let result_2 = output.pop().unwrap();
+ let result_1 = output.pop().unwrap();
+
+ assert_eq!(result_1, (9 * 3 + 10 * 4 + 11 * 5 + 1) as f64);
+ assert_eq!(result_2, (9 * 6 + 10 * 7 + 11 * 8 + 2) as f64);
+ }
+}
diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs
index 35daaf9..73e528b 100644
--- a/little_learner/src/lib.rs
+++ b/little_learner/src/lib.rs
@@ -6,6 +6,7 @@ pub mod auto_diff;
pub mod decider;
pub mod gradient_descent;
pub mod hyper;
+pub mod layer;
pub mod loss;
pub mod not_nan;
pub mod predictor;
diff --git a/little_learner/src/not_nan.rs b/little_learner/src/not_nan.rs
index 02b698b..416235d 100644
--- a/little_learner/src/not_nan.rs
+++ b/little_learner/src/not_nan.rs
@@ -7,6 +7,10 @@ where
xs.map(|x| NotNan::new(x).expect("not nan"))
}
+pub fn from_not_nan_1(xs: [NotNan; N]) -> [T; N] {
+ xs.map(|x| x.into_inner())
+}
+
pub fn to_not_nan_2(xs: [[T; N]; M]) -> [[NotNan; N]; M]
where
T: ordered_float::Float,