Relu (#24)
This commit is contained in:
@@ -145,13 +145,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
|
||||
|
||||
fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableContents<A, Tag2>
|
||||
where
|
||||
F: FnMut(Tag) -> Tag2,
|
||||
F: FnMut(&Tag) -> Tag2,
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
match self {
|
||||
DifferentiableContents::Scalar(a, tag) => {
|
||||
DifferentiableContents::Scalar((*a).clone(), f((*tag).clone()))
|
||||
DifferentiableContents::Scalar((*a).clone(), f(tag))
|
||||
}
|
||||
DifferentiableContents::Vector(slice, rank) => {
|
||||
DifferentiableContents::Vector(slice.iter().map(|x| x.map_tag(f)).collect(), *rank)
|
||||
@@ -253,9 +252,8 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
|
||||
|
||||
pub fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableTagged<A, Tag2>
|
||||
where
|
||||
F: FnMut(Tag) -> Tag2,
|
||||
F: FnMut(&Tag) -> Tag2,
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
DifferentiableTagged {
|
||||
contents: self.contents.map_tag(f),
|
||||
@@ -572,13 +570,12 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
|
||||
}
|
||||
|
||||
pub fn map_tag<Tag2, F>(
|
||||
self: RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
self: &RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
f: &mut F,
|
||||
) -> RankedDifferentiableTagged<A, Tag2, RANK>
|
||||
where
|
||||
A: Clone,
|
||||
F: FnMut(Tag) -> Tag2,
|
||||
Tag: Clone,
|
||||
F: FnMut(&Tag) -> Tag2,
|
||||
{
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::map_tag(&self.contents, f),
|
||||
|
68
little_learner/src/decider.rs
Normal file
68
little_learner/src/decider.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
use crate::auto_diff::RankedDifferentiableTagged;
|
||||
use crate::loss::dot;
|
||||
use crate::scalar::Scalar;
|
||||
use crate::traits::{NumLike, Zero};
|
||||
|
||||
fn rectify<A>(x: A) -> A
|
||||
where
|
||||
A: Zero + PartialOrd,
|
||||
{
|
||||
if x < A::zero() {
|
||||
A::zero()
|
||||
} else {
|
||||
x
|
||||
}
|
||||
}
|
||||
|
||||
fn linear<A, Tag1, Tag2>(
|
||||
t: RankedDifferentiableTagged<A, Tag1, 1>,
|
||||
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
|
||||
theta1: Scalar<A>,
|
||||
) -> Scalar<A>
|
||||
where
|
||||
A: NumLike,
|
||||
{
|
||||
dot(&theta0, &t) + theta1
|
||||
}
|
||||
|
||||
pub fn relu<A, Tag1, Tag2>(
|
||||
t: RankedDifferentiableTagged<A, Tag1, 1>,
|
||||
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
|
||||
theta1: Scalar<A>,
|
||||
) -> Scalar<A>
|
||||
where
|
||||
A: NumLike + PartialOrd,
|
||||
{
|
||||
rectify(linear(t, theta0, theta1))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod test_decider {
    use crate::auto_diff::RankedDifferentiable;
    use crate::decider::{linear, relu};
    use crate::not_nan::to_not_nan_1;
    use crate::scalar::Scalar;
    use ordered_float::NotNan;

    #[test]
    fn test_linear() {
        // theta0 · t + theta1 = 7.1*2.0 + 4.3*1.0 + (-6.4)*3.0 + 0.6 = -0.1
        let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
        let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
        let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));

        let result = linear(t, theta0, theta1).real_part().into_inner();

        // Compare with a tolerance: -0.1 is not exactly representable in
        // binary floating point, so exact equality would be fragile here.
        assert!((result + 0.1).abs() < 0.000_000_01);
    }

    #[test]
    fn test_relu() {
        // Same inputs as test_linear: the pre-activation is -0.1, which is
        // negative, so the rectifier clamps the output to zero.
        let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
        let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
        let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));

        let result = relu(t, theta0, theta1).real_part().into_inner();

        // Exact equality is intentional: rectify returns A::zero() verbatim
        // for negative inputs, so no floating-point tolerance is needed.
        assert_eq!(result, 0.0);
    }
}
|
@@ -3,6 +3,7 @@
|
||||
#![feature(array_methods)]
|
||||
|
||||
pub mod auto_diff;
|
||||
pub mod decider;
|
||||
pub mod gradient_descent;
|
||||
pub mod hyper;
|
||||
pub mod loss;
|
||||
|
@@ -3,7 +3,7 @@ use std::{
|
||||
ops::{Add, Mul, Neg},
|
||||
};
|
||||
|
||||
use crate::auto_diff::Differentiable;
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
|
||||
use crate::{
|
||||
auto_diff::{DifferentiableTagged, RankedDifferentiable},
|
||||
scalar::Scalar,
|
||||
@@ -50,6 +50,23 @@ where
|
||||
dot_unranked_tagged(x, y, |(), ()| ())
|
||||
}
|
||||
|
||||
/// Dot product of two rank-1 tensors that may carry different tag types.
///
/// Both inputs have their tags erased to `()` first so the two sides share
/// a common type, then the elements are multiplied pairwise and summed.
pub fn dot<A, Tag1, Tag2>(
    x: &RankedDifferentiableTagged<A, Tag1, 1>,
    y: &RankedDifferentiableTagged<A, Tag2, 1>,
) -> Scalar<A>
where
    A: Mul<Output = A> + Sum + Clone + Add<Output = A> + Zero,
{
    // Much sadness - find a way to get rid of these clones
    // (erasing the tags via `map_tag` clones the underlying structure just
    // to unify the two tag types).
    let x = x.map_tag(&mut |_| ());
    let y = y.map_tag(&mut |_| ());
    // NOTE(review): `zip` silently truncates to the shorter vector if the
    // two rank-1 tensors differ in length — confirm callers guarantee
    // equal lengths.
    x.to_vector()
        .iter()
        .zip(y.to_vector().iter())
        .map(|(x, y)| x.clone().to_scalar() * y.clone().to_scalar())
        .sum()
}
|
||||
|
||||
fn squared_2<A, const RANK: usize>(
|
||||
x: &RankedDifferentiable<A, RANK>,
|
||||
) -> RankedDifferentiable<A, RANK>
|
||||
|
@@ -1,5 +1,6 @@
|
||||
use crate::traits::{Exp, One, Sqrt, Zero};
|
||||
use core::hash::Hash;
|
||||
use std::cmp::Ordering;
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
fmt::Display,
|
||||
@@ -237,6 +238,15 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Scalars are ordered by their real (primal) part only; any other
/// structure attached to the scalar plays no role in the comparison.
impl<A> PartialOrd for Scalar<A>
where
    A: PartialOrd + Clone,
{
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // NOTE(review): assumes `PartialEq` for `Scalar` is consistent with
        // comparing real parts (PartialOrd's contract requires
        // `a == b` iff `partial_cmp` returns `Some(Equal)`) — TODO confirm
        // against the `PartialEq` impl, which is not visible here.
        self.real_part().partial_cmp(other.real_part())
    }
}
|
||||
|
||||
impl<A> Exp for Scalar<A>
|
||||
where
|
||||
A: Exp + Clone,
|
||||
|
Reference in New Issue
Block a user