This commit is contained in:
Patrick Stevens
2023-05-08 14:43:27 +01:00
committed by GitHub
parent fac93253f2
commit 1ee76d4bc3
5 changed files with 102 additions and 9 deletions

View File

@@ -145,13 +145,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableContents<A, Tag2>
where
F: FnMut(Tag) -> Tag2,
F: FnMut(&Tag) -> Tag2,
A: Clone,
Tag: Clone,
{
match self {
DifferentiableContents::Scalar(a, tag) => {
DifferentiableContents::Scalar((*a).clone(), f((*tag).clone()))
DifferentiableContents::Scalar((*a).clone(), f(tag))
}
DifferentiableContents::Vector(slice, rank) => {
DifferentiableContents::Vector(slice.iter().map(|x| x.map_tag(f)).collect(), *rank)
@@ -253,9 +252,8 @@ impl<A, Tag> DifferentiableTagged<A, Tag> {
pub fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableTagged<A, Tag2>
where
F: FnMut(Tag) -> Tag2,
F: FnMut(&Tag) -> Tag2,
A: Clone,
Tag: Clone,
{
DifferentiableTagged {
contents: self.contents.map_tag(f),
@@ -572,13 +570,12 @@ impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
}
pub fn map_tag<Tag2, F>(
self: RankedDifferentiableTagged<A, Tag, RANK>,
self: &RankedDifferentiableTagged<A, Tag, RANK>,
f: &mut F,
) -> RankedDifferentiableTagged<A, Tag2, RANK>
where
A: Clone,
F: FnMut(Tag) -> Tag2,
Tag: Clone,
F: FnMut(&Tag) -> Tag2,
{
RankedDifferentiableTagged {
contents: DifferentiableTagged::map_tag(&self.contents, f),

View File

@@ -0,0 +1,68 @@
use crate::auto_diff::RankedDifferentiableTagged;
use crate::loss::dot;
use crate::scalar::Scalar;
use crate::traits::{NumLike, Zero};
fn rectify<A>(x: A) -> A
where
A: Zero + PartialOrd,
{
if x < A::zero() {
A::zero()
} else {
x
}
}
fn linear<A, Tag1, Tag2>(
t: RankedDifferentiableTagged<A, Tag1, 1>,
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
theta1: Scalar<A>,
) -> Scalar<A>
where
A: NumLike,
{
dot(&theta0, &t) + theta1
}
pub fn relu<A, Tag1, Tag2>(
t: RankedDifferentiableTagged<A, Tag1, 1>,
theta0: RankedDifferentiableTagged<A, Tag2, 1>,
theta1: Scalar<A>,
) -> Scalar<A>
where
A: NumLike + PartialOrd,
{
rectify(linear(t, theta0, theta1))
}
#[cfg(test)]
mod test_decider {
use crate::auto_diff::RankedDifferentiable;
use crate::decider::{linear, relu};
use crate::not_nan::to_not_nan_1;
use crate::scalar::Scalar;
use ordered_float::NotNan;
#[test]
fn test_linear() {
let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
let result = linear(t, theta0, theta1).real_part().into_inner();
assert!((result + 0.1).abs() < 0.000_000_01);
}
#[test]
fn test_relu() {
let theta0 = RankedDifferentiable::of_slice(&to_not_nan_1([7.1, 4.3, -6.4]));
let theta1 = Scalar::make(NotNan::new(0.6).expect("not nan"));
let t = RankedDifferentiable::of_slice(&to_not_nan_1([2.0, 1.0, 3.0]));
let result = relu(t, theta0, theta1).real_part().into_inner();
assert_eq!(result, 0.0);
}
}

View File

@@ -3,6 +3,7 @@
#![feature(array_methods)]
pub mod auto_diff;
pub mod decider;
pub mod gradient_descent;
pub mod hyper;
pub mod loss;

View File

@@ -3,7 +3,7 @@ use std::{
ops::{Add, Mul, Neg},
};
use crate::auto_diff::Differentiable;
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
use crate::{
auto_diff::{DifferentiableTagged, RankedDifferentiable},
scalar::Scalar,
@@ -50,6 +50,23 @@ where
dot_unranked_tagged(x, y, |(), ()| ())
}
pub fn dot<A, Tag1, Tag2>(
x: &RankedDifferentiableTagged<A, Tag1, 1>,
y: &RankedDifferentiableTagged<A, Tag2, 1>,
) -> Scalar<A>
where
A: Mul<Output = A> + Sum + Clone + Add<Output = A> + Zero,
{
// Much sadness - find a way to get rid of these clones
let x = x.map_tag(&mut |_| ());
let y = y.map_tag(&mut |_| ());
x.to_vector()
.iter()
.zip(y.to_vector().iter())
.map(|(x, y)| x.clone().to_scalar() * y.clone().to_scalar())
.sum()
}
fn squared_2<A, const RANK: usize>(
x: &RankedDifferentiable<A, RANK>,
) -> RankedDifferentiable<A, RANK>

View File

@@ -1,5 +1,6 @@
use crate::traits::{Exp, One, Sqrt, Zero};
use core::hash::Hash;
use std::cmp::Ordering;
use std::{
collections::{hash_map::Entry, HashMap},
fmt::Display,
@@ -237,6 +238,15 @@ where
}
}
impl<A> PartialOrd for Scalar<A>
where
A: PartialOrd + Clone,
{
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.real_part().partial_cmp(other.real_part())
}
}
impl<A> Exp for Scalar<A>
where
A: Exp + Clone,