From 1ee76d4bc378d835e83ccf6281ca8f519aa9366a Mon Sep 17 00:00:00 2001
From: Patrick Stevens
Date: Mon, 8 May 2023 14:43:27 +0100
Subject: [PATCH] Relu (#24)

---
 little_learner/src/auto_diff.rs | 13 +++----
 little_learner/src/decider.rs   | 68 +++++++++++++++++++++++++++++++++
 little_learner/src/lib.rs       |  1 +
 little_learner/src/loss.rs      | 19 ++++++++-
 little_learner/src/scalar.rs    | 10 +++++
 5 files changed, 102 insertions(+), 9 deletions(-)
 create mode 100644 little_learner/src/decider.rs

diff --git a/little_learner/src/auto_diff.rs b/little_learner/src/auto_diff.rs
index e32511f..539362c 100644
--- a/little_learner/src/auto_diff.rs
+++ b/little_learner/src/auto_diff.rs
@@ -145,13 +145,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
     fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableContents<A, Tag2>
     where
-        F: FnMut(Tag) -> Tag2,
+        F: FnMut(&Tag) -> Tag2,
         A: Clone,
-        Tag: Clone,
     {
         match self {
             DifferentiableContents::Scalar(a, tag) => {
-                DifferentiableContents::Scalar((*a).clone(), f((*tag).clone()))
+                DifferentiableContents::Scalar((*a).clone(), f(tag))
             }
             DifferentiableContents::Vector(slice, rank) => {
                 DifferentiableContents::Vector(slice.iter().map(|x| x.map_tag(f)).collect(), *rank)
             }
@@ -253,9 +252,8 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
     pub fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableTagged<A, Tag2>
     where
-        F: FnMut(Tag) -> Tag2,
+        F: FnMut(&Tag) -> Tag2,
         A: Clone,
-        Tag: Clone,
     {
         DifferentiableTagged {
             contents: self.contents.map_tag(f),
         }
@@ -572,13 +570,12 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
     }
 
     pub fn map_tag<Tag2, F>(
-        self: RankedDifferentiableTagged<A, Tag, RANK>,
+        self: &RankedDifferentiableTagged<A, Tag, RANK>,
         f: &mut F,
     ) -> RankedDifferentiableTagged<A, Tag2, RANK>
     where
         A: Clone,
-        F: FnMut(Tag) -> Tag2,
-        Tag: Clone,
+        F: FnMut(&Tag) -> Tag2,
     {
         RankedDifferentiableTagged {
             contents: DifferentiableTagged::map_tag(&self.contents, f),
diff --git a/little_learner/src/decider.rs b/little_learner/src/decider.rs
new file mode 100644
index 0000000..9dcbe6c
--- /dev/null
+++ b/little_learner/src/decider.rs
@@ -0,0 +1,68 @@
+use crate::auto_diff::RankedDifferentiableTagged;
+use crate::loss::dot;
+use crate::scalar::Scalar;
+use crate::traits::{NumLike, Zero};
+
+fn rectify<A>(x: A) -> A
+where
+    A: Zero + PartialOrd,
+{
+    if x < A::zero() {
+        A::zero()
+    } else {
+        x
+    }
+}
+
+fn linear<A, Tag, Tag2>(
+    t: RankedDifferentiableTagged<A, Tag, 1>,
+    theta0: RankedDifferentiableTagged<A, Tag2, 1>,
+    theta1: Scalar<A>,
+) -> Scalar<A>
+where
+    A: NumLike,
+{
+    dot(&theta0, &t) + theta1
+}
+
+pub fn relu<A, Tag, Tag2>(
+    t: RankedDifferentiableTagged<A, Tag, 1>,
+    theta0: RankedDifferentiableTagged<A, Tag2, 1>,
+    theta1: Scalar<A>,
+) -> Scalar<A>
+where
+    A: NumLike + PartialOrd,
+{
+    rectify(linear(t, theta0, theta1))
+}
+
+#[cfg(test)]
+mod test_decider {
+    use crate::auto_diff::RankedDifferentiable;
+    use crate::decider::{linear, relu};
+    use crate::not_nan::to_not_nan_1;
+    use crate::scalar::Scalar;
+    use ordered_float::NotNan;
+
+    #[test]
+    fn test_linear() {
+        let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
+        let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
+        let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
+
+        let result = linear(t, theta0, theta1).real_part().into_inner();
+
+        assert!((result + 0.1).abs() < 0.000_000_01);
+    }
+
+    #[test]
+    fn test_relu() {
+        let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
+        let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
+        let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
+
+        let result = relu(t, theta0, theta1).real_part().into_inner();
+
+        assert_eq!(result, 0.0);
+    }
+}
diff --git a/little_learner/src/lib.rs b/little_learner/src/lib.rs
index e0aa15e..35daaf9 100644
--- a/little_learner/src/lib.rs
+++ b/little_learner/src/lib.rs
@@ -3,6 +3,7 @@
 #![feature(array_methods)]
 
 pub mod auto_diff;
+pub mod decider;
 pub mod gradient_descent;
 pub mod hyper;
 pub mod loss;
diff --git a/little_learner/src/loss.rs b/little_learner/src/loss.rs
index 3e08e15..44d2eaf 100644
--- a/little_learner/src/loss.rs
+++ b/little_learner/src/loss.rs
@@ -3,7 +3,7 @@ use std::{
     ops::{Add, Mul, Neg},
 };
 
-use crate::auto_diff::Differentiable;
+use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
 use crate::{
     auto_diff::{DifferentiableTagged, RankedDifferentiable},
     scalar::Scalar,
@@ -50,6 +50,23 @@ where
     dot_unranked_tagged(x, y, |(), ()| ())
 }
 
+pub fn dot<A, Tag1, Tag2>(
+    x: &RankedDifferentiableTagged<A, Tag1, 1>,
+    y: &RankedDifferentiableTagged<A, Tag2, 1>,
+) -> Scalar<A>
+where
+    A: Mul + Sum + Clone + Add + Zero,
+{
+    // Much sadness - find a way to get rid of these clones
+    let x = x.map_tag(&mut |_| ());
+    let y = y.map_tag(&mut |_| ());
+    x.to_vector()
+        .iter()
+        .zip(y.to_vector().iter())
+        .map(|(x, y)| x.clone().to_scalar() * y.clone().to_scalar())
+        .sum()
+}
+
 fn squared_2<A>(
     x: &RankedDifferentiable<A, 2>,
 ) -> RankedDifferentiable<A, 2>
diff --git a/little_learner/src/scalar.rs b/little_learner/src/scalar.rs
index ebe5936..72ebc88 100644
--- a/little_learner/src/scalar.rs
+++ b/little_learner/src/scalar.rs
@@ -1,5 +1,6 @@
 use crate::traits::{Exp, One, Sqrt, Zero};
 use core::hash::Hash;
+use std::cmp::Ordering;
 use std::{
     collections::{hash_map::Entry, HashMap},
     fmt::Display,
@@ -237,6 +238,15 @@
     }
 }
 
+impl<A> PartialOrd for Scalar<A>
+where
+    A: PartialOrd + Clone,
+{
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        self.real_part().partial_cmp(other.real_part())
+    }
+}
+
 impl<A> Exp for Scalar<A>
 where
     A: Exp + Clone,