Add rank parameters to autodiff (#6)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,3 +2,4 @@ target/
|
|||||||
.idea/
|
.idea/
|
||||||
*.iml
|
*.iml
|
||||||
.vscode/
|
.vscode/
|
||||||
|
.profile*
|
||||||
|
@@ -37,8 +37,8 @@
|
|||||||
# Because rust-overlay bundles multiple rust packages into one
|
# Because rust-overlay bundles multiple rust packages into one
|
||||||
# derivation, specify that mega-bundle here, so that crate2nix
|
# derivation, specify that mega-bundle here, so that crate2nix
|
||||||
# will use them automatically.
|
# will use them automatically.
|
||||||
rustc = self.rust-bin.stable.latest.default;
|
rustc = self.rust-bin.nightly.latest.default;
|
||||||
cargo = self.rust-bin.stable.latest.default;
|
cargo = self.rust-bin.nightly.latest.default;
|
||||||
})
|
})
|
||||||
];
|
];
|
||||||
};
|
};
|
||||||
|
1
little_learner/rust-toolchain
Normal file
1
little_learner/rust-toolchain
Normal file
@@ -0,0 +1 @@
|
|||||||
|
nightly
|
@@ -1,265 +1,64 @@
|
|||||||
|
use crate::scalar::Scalar;
|
||||||
|
use crate::traits::{Exp, One, Zero};
|
||||||
use core::hash::Hash;
|
use core::hash::Hash;
|
||||||
use ordered_float::NotNan;
|
use std::collections::HashMap;
|
||||||
use std::{
|
use std::{
|
||||||
collections::{hash_map::Entry, HashMap},
|
|
||||||
fmt::{Display, Write},
|
fmt::{Display, Write},
|
||||||
ops::{Add, AddAssign, Div, Mul},
|
ops::{AddAssign, Div, Mul, Neg},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub trait Zero {
|
impl<A> Zero for DifferentiableHidden<A>
|
||||||
fn zero() -> Self;
|
|
||||||
}
|
|
||||||
|
|
||||||
pub trait One {
|
|
||||||
fn one() -> Self;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Zero for f64 {
|
|
||||||
fn zero() -> Self {
|
|
||||||
0.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl One for f64 {
|
|
||||||
fn one() -> Self {
|
|
||||||
1.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Zero for NotNan<f64> {
|
|
||||||
fn zero() -> Self {
|
|
||||||
NotNan::new(0.0).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl One for NotNan<f64> {
|
|
||||||
fn one() -> Self {
|
|
||||||
NotNan::new(1.0).unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Zero for Differentiable<A>
|
|
||||||
where
|
where
|
||||||
A: Zero,
|
A: Zero,
|
||||||
{
|
{
|
||||||
fn zero() -> Differentiable<A> {
|
fn zero() -> DifferentiableHidden<A> {
|
||||||
Differentiable::Scalar(Scalar::Number(A::zero()))
|
DifferentiableHidden::Scalar(Scalar::Number(A::zero()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> One for Differentiable<A>
|
impl<A> One for Scalar<A>
|
||||||
where
|
where
|
||||||
A: One,
|
A: One,
|
||||||
{
|
{
|
||||||
fn one() -> Differentiable<A> {
|
fn one() -> Scalar<A> {
|
||||||
Differentiable::Scalar(Scalar::Number(A::one()))
|
Scalar::Number(A::one())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Exp {
|
impl<A> One for DifferentiableHidden<A>
|
||||||
fn exp(self) -> Self;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Exp for NotNan<f64> {
|
|
||||||
fn exp(self) -> Self {
|
|
||||||
NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
|
||||||
pub enum LinkData<A> {
|
|
||||||
Addition(Box<Scalar<A>>, Box<Scalar<A>>),
|
|
||||||
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
|
|
||||||
Exponent(Box<Scalar<A>>),
|
|
||||||
Log(Box<Scalar<A>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
|
||||||
pub enum Link<A> {
|
|
||||||
EndOfLink,
|
|
||||||
Link(LinkData<A>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Display for Link<A>
|
|
||||||
where
|
where
|
||||||
A: Display,
|
A: One,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn one() -> DifferentiableHidden<A> {
|
||||||
match self {
|
DifferentiableHidden::Scalar(Scalar::one())
|
||||||
Link::EndOfLink => f.write_str("<end>"),
|
|
||||||
Link::Link(LinkData::Addition(left, right)) => {
|
|
||||||
f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref()))
|
|
||||||
}
|
|
||||||
Link::Link(LinkData::Mul(left, right)) => {
|
|
||||||
f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref()))
|
|
||||||
}
|
|
||||||
Link::Link(LinkData::Exponent(arg)) => {
|
|
||||||
f.write_fmt(format_args!("exp({})", arg.as_ref()))
|
|
||||||
}
|
|
||||||
Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Link<A> {
|
impl<A> Clone for DifferentiableHidden<A>
|
||||||
fn invoke(self, d: &Scalar<A>, z: A, acc: &mut HashMap<Scalar<A>, A>)
|
|
||||||
where
|
|
||||||
A: Eq + Hash + AddAssign + Clone + Exp + Mul<Output = A> + Div<Output = A> + Zero + One,
|
|
||||||
{
|
|
||||||
match self {
|
|
||||||
Link::EndOfLink => match acc.entry(d.clone()) {
|
|
||||||
Entry::Occupied(mut o) => {
|
|
||||||
let entry = o.get_mut();
|
|
||||||
*entry += z;
|
|
||||||
}
|
|
||||||
Entry::Vacant(v) => {
|
|
||||||
v.insert(z);
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Link::Link(data) => {
|
|
||||||
match data {
|
|
||||||
LinkData::Addition(left, right) => {
|
|
||||||
// The `z` here reflects the fact that dx/dx = 1, so it's 1 * z.
|
|
||||||
left.as_ref().clone_link().invoke(&left, z.clone(), acc);
|
|
||||||
right.as_ref().clone_link().invoke(&right, z, acc);
|
|
||||||
}
|
|
||||||
LinkData::Exponent(arg) => {
|
|
||||||
// d/dx (e^x) = exp x, so exp z * z.
|
|
||||||
arg.as_ref().clone_link().invoke(
|
|
||||||
&arg,
|
|
||||||
z * arg.clone_real_part().exp(),
|
|
||||||
acc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
LinkData::Mul(left, right) => {
|
|
||||||
// d/dx(f g) = f dg/dx + g df/dx
|
|
||||||
left.as_ref().clone_link().invoke(
|
|
||||||
&left,
|
|
||||||
right.clone_real_part() * z.clone(),
|
|
||||||
acc,
|
|
||||||
);
|
|
||||||
right
|
|
||||||
.as_ref()
|
|
||||||
.clone_link()
|
|
||||||
.invoke(&right, left.clone_real_part() * z, acc);
|
|
||||||
}
|
|
||||||
LinkData::Log(arg) => {
|
|
||||||
// d/dx(log y) = 1/y dy/dx
|
|
||||||
arg.as_ref().clone_link().invoke(
|
|
||||||
&arg,
|
|
||||||
A::one() / arg.clone_real_part() * z,
|
|
||||||
acc,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Hash, PartialEq, Eq)]
|
|
||||||
pub enum Scalar<A> {
|
|
||||||
Number(A),
|
|
||||||
// The value, and the link.
|
|
||||||
Dual(A, Link<A>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Add for Scalar<A>
|
|
||||||
where
|
where
|
||||||
A: Add<Output = A> + Clone,
|
A: Clone,
|
||||||
{
|
{
|
||||||
type Output = Scalar<A>;
|
fn clone(&self) -> Self {
|
||||||
|
|
||||||
fn add(self, rhs: Self) -> Self::Output {
|
|
||||||
Scalar::Dual(
|
|
||||||
self.clone_real_part() + rhs.clone_real_part(),
|
|
||||||
Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Mul for Scalar<A>
|
|
||||||
where
|
|
||||||
A: Mul<Output = A> + Clone,
|
|
||||||
{
|
|
||||||
type Output = Scalar<A>;
|
|
||||||
|
|
||||||
fn mul(self, rhs: Self) -> Self::Output {
|
|
||||||
Scalar::Dual(
|
|
||||||
self.clone_real_part() * rhs.clone_real_part(),
|
|
||||||
Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Scalar<A> {
|
|
||||||
pub fn real_part(&self) -> &A {
|
|
||||||
match self {
|
match self {
|
||||||
Scalar::Number(a) => a,
|
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
|
||||||
Scalar::Dual(a, _) => a,
|
Self::Vector(arg0) => Self::Vector(arg0.clone()),
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clone_real_part(&self) -> A
|
|
||||||
where
|
|
||||||
A: Clone,
|
|
||||||
{
|
|
||||||
match self {
|
|
||||||
Scalar::Number(a) => (*a).clone(),
|
|
||||||
Scalar::Dual(a, _) => (*a).clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn link(self) -> Link<A> {
|
|
||||||
match self {
|
|
||||||
Scalar::Dual(_, link) => link,
|
|
||||||
Scalar::Number(_) => Link::EndOfLink,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn clone_link(&self) -> Link<A>
|
|
||||||
where
|
|
||||||
A: Clone,
|
|
||||||
{
|
|
||||||
match self {
|
|
||||||
Scalar::Dual(_, data) => data.clone(),
|
|
||||||
Scalar::Number(_) => Link::EndOfLink,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn truncate_dual(self) -> Scalar<A>
|
|
||||||
where
|
|
||||||
A: Clone,
|
|
||||||
{
|
|
||||||
Scalar::Dual(self.clone_real_part(), Link::EndOfLink)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Display for Scalar<A>
|
|
||||||
where
|
|
||||||
A: Display,
|
|
||||||
{
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
match self {
|
|
||||||
Scalar::Number(n) => f.write_fmt(format_args!("{}", n)),
|
|
||||||
Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub enum Differentiable<A> {
|
enum DifferentiableHidden<A> {
|
||||||
Scalar(Scalar<A>),
|
Scalar(Scalar<A>),
|
||||||
Vector(Box<[Differentiable<A>]>),
|
Vector(Vec<DifferentiableHidden<A>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Display for Differentiable<A>
|
impl<A> Display for DifferentiableHidden<A>
|
||||||
where
|
where
|
||||||
A: Display,
|
A: Display,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
DifferentiableHidden::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
||||||
Differentiable::Vector(v) => {
|
DifferentiableHidden::Vector(v) => {
|
||||||
f.write_char('[')?;
|
f.write_char('[')?;
|
||||||
for v in v.iter() {
|
for v in v.iter() {
|
||||||
f.write_fmt(format_args!("{}", v))?;
|
f.write_fmt(format_args!("{}", v))?;
|
||||||
@@ -271,26 +70,70 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Differentiable<A> {
|
impl<A> DifferentiableHidden<A> {
|
||||||
pub fn map<B, F>(&self, f: &F) -> Differentiable<B>
|
fn map<B, F>(&self, f: &F) -> DifferentiableHidden<B>
|
||||||
where
|
where
|
||||||
F: Fn(Scalar<A>) -> Scalar<B>,
|
F: Fn(Scalar<A>) -> Scalar<B>,
|
||||||
A: Clone,
|
A: Clone,
|
||||||
{
|
{
|
||||||
match self {
|
match self {
|
||||||
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
|
DifferentiableHidden::Scalar(a) => DifferentiableHidden::Scalar(f(a.clone())),
|
||||||
Differentiable::Vector(slice) => {
|
DifferentiableHidden::Vector(slice) => {
|
||||||
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
|
DifferentiableHidden::Vector(slice.iter().map(|x| x.map(f)).collect())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn map2<B, C, F>(&self, other: &DifferentiableHidden<B>, f: &F) -> DifferentiableHidden<C>
|
||||||
|
where
|
||||||
|
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||||
|
A: Clone,
|
||||||
|
B: Clone,
|
||||||
|
{
|
||||||
|
match (self, other) {
|
||||||
|
(DifferentiableHidden::Scalar(a), DifferentiableHidden::Scalar(b)) => {
|
||||||
|
DifferentiableHidden::Scalar(f(a, b))
|
||||||
|
}
|
||||||
|
(DifferentiableHidden::Vector(slice_a), DifferentiableHidden::Vector(slice_b)) => {
|
||||||
|
DifferentiableHidden::Vector(
|
||||||
|
slice_a
|
||||||
|
.iter()
|
||||||
|
.zip(slice_b.iter())
|
||||||
|
.map(|(a, b)| a.map2(b, f))
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => panic!("Wrong shapes!"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn of_slice(input: &[A]) -> DifferentiableHidden<A>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
{
|
||||||
|
DifferentiableHidden::Vector(
|
||||||
|
input
|
||||||
|
.iter()
|
||||||
|
.map(|v| DifferentiableHidden::Scalar(Scalar::Number((*v).clone())))
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Differentiable<A>
|
impl<A> DifferentiableHidden<A>
|
||||||
where
|
where
|
||||||
A: Clone + Eq + Hash + AddAssign + Mul<Output = A> + Exp + Div<Output = A> + Zero + One,
|
A: Clone
|
||||||
|
+ Eq
|
||||||
|
+ Hash
|
||||||
|
+ AddAssign
|
||||||
|
+ Mul<Output = A>
|
||||||
|
+ Exp
|
||||||
|
+ Div<Output = A>
|
||||||
|
+ Zero
|
||||||
|
+ One
|
||||||
|
+ Neg<Output = A>,
|
||||||
{
|
{
|
||||||
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
|
fn accumulate_gradients_vec(v: &[DifferentiableHidden<A>], acc: &mut HashMap<Scalar<A>, A>) {
|
||||||
for v in v.iter().rev() {
|
for v in v.iter().rev() {
|
||||||
v.accumulate_gradients(acc);
|
v.accumulate_gradients(acc);
|
||||||
}
|
}
|
||||||
@@ -298,15 +141,17 @@ where
|
|||||||
|
|
||||||
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
|
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
|
||||||
match self {
|
match self {
|
||||||
Differentiable::Scalar(y) => {
|
DifferentiableHidden::Scalar(y) => {
|
||||||
let k = y.clone_link();
|
let k = y.clone_link();
|
||||||
k.invoke(y, A::one(), acc);
|
k.invoke(y, A::one(), acc);
|
||||||
}
|
}
|
||||||
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc),
|
DifferentiableHidden::Vector(y) => {
|
||||||
|
DifferentiableHidden::accumulate_gradients_vec(y, acc)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn grad_once(self, wrt: Differentiable<A>) -> Differentiable<A> {
|
fn grad_once(self, wrt: &DifferentiableHidden<A>) -> DifferentiableHidden<A> {
|
||||||
let mut acc = HashMap::new();
|
let mut acc = HashMap::new();
|
||||||
self.accumulate_gradients(&mut acc);
|
self.accumulate_gradients(&mut acc);
|
||||||
|
|
||||||
@@ -315,34 +160,133 @@ where
|
|||||||
Some(x) => Scalar::Number(x.clone()),
|
Some(x) => Scalar::Number(x.clone()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn grad<F>(f: F, theta: Differentiable<A>) -> Differentiable<A>
|
#[derive(Clone)]
|
||||||
|
pub struct Differentiable<A, const RANK: usize> {
|
||||||
|
contents: DifferentiableHidden<A>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A, const RANK: usize> Display for Differentiable<A, RANK>
|
||||||
|
where
|
||||||
|
A: Display,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
Display::fmt(&self.contents, f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn of_scalar<A>(s: Scalar<A>) -> Differentiable<A, 0> {
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableHidden::Scalar(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_scalar<A>(s: Differentiable<A, 0>) -> Scalar<A> {
|
||||||
|
match s.contents {
|
||||||
|
DifferentiableHidden::Scalar(s) => s,
|
||||||
|
DifferentiableHidden::Vector(_) => panic!("not a vector"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn of_slice<A>(input: &[A]) -> Differentiable<A, 1>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
{
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableHidden::of_slice(input),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A, const RANK: usize> Differentiable<A, RANK> {
|
||||||
|
pub fn of_vector(s: Vec<Differentiable<A, RANK>>) -> Differentiable<A, { RANK + 1 }> {
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableHidden::Vector(s.into_iter().map(|v| v.contents).collect()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn map<B, F>(s: Differentiable<A, RANK>, f: &F) -> Differentiable<B, RANK>
|
||||||
where
|
where
|
||||||
F: Fn(&Differentiable<A>) -> Differentiable<A>,
|
F: Fn(Scalar<A>) -> Scalar<B>,
|
||||||
|
A: Clone,
|
||||||
{
|
{
|
||||||
let wrt = theta.map(&Scalar::truncate_dual);
|
Differentiable {
|
||||||
let after_f = f(&wrt);
|
contents: DifferentiableHidden::map(&s.contents, f),
|
||||||
Differentiable::grad_once(after_f, wrt)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn map2<B, C, F>(
|
||||||
|
self: &Differentiable<A, RANK>,
|
||||||
|
other: &Differentiable<B, RANK>,
|
||||||
|
f: &F,
|
||||||
|
) -> Differentiable<C, RANK>
|
||||||
|
where
|
||||||
|
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||||
|
A: Clone,
|
||||||
|
B: Clone,
|
||||||
|
{
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableHidden::map2(&self.contents, &other.contents, f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_vector(s: Differentiable<A, { RANK + 1 }>) -> Vec<Differentiable<A, RANK>> {
|
||||||
|
match s.contents {
|
||||||
|
DifferentiableHidden::Scalar(_) => panic!("not a scalar"),
|
||||||
|
DifferentiableHidden::Vector(v) => v
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| Differentiable { contents: v })
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn grad<F>(f: F, theta: Differentiable<A, RANK>) -> Differentiable<A, RANK>
|
||||||
|
where
|
||||||
|
F: Fn(Differentiable<A, RANK>) -> Differentiable<A, RANK>,
|
||||||
|
A: Clone
|
||||||
|
+ Hash
|
||||||
|
+ AddAssign
|
||||||
|
+ Mul<Output = A>
|
||||||
|
+ Exp
|
||||||
|
+ Div<Output = A>
|
||||||
|
+ Zero
|
||||||
|
+ One
|
||||||
|
+ Neg<Output = A>
|
||||||
|
+ Eq,
|
||||||
|
{
|
||||||
|
let wrt = theta.contents.map(&Scalar::truncate_dual);
|
||||||
|
let after_f = f(Differentiable {
|
||||||
|
contents: wrt.clone(),
|
||||||
|
});
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableHidden::grad_once(after_f.contents, &wrt),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use ordered_float::NotNan;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
|
fn extract_scalar<'a, A>(d: &'a DifferentiableHidden<A>) -> &'a A {
|
||||||
match d {
|
match d {
|
||||||
Differentiable::Scalar(a) => &(a.real_part()),
|
DifferentiableHidden::Scalar(a) => &(a.real_part()),
|
||||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
DifferentiableHidden::Vector(_) => panic!("not a scalar"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_map() {
|
fn test_map() {
|
||||||
let v = Differentiable::Vector(
|
let v = DifferentiableHidden::Vector(
|
||||||
vec![
|
vec![
|
||||||
Differentiable::Scalar(Scalar::Number(NotNan::new(3.0).expect("3 is not NaN"))),
|
DifferentiableHidden::Scalar(Scalar::Number(
|
||||||
Differentiable::Scalar(Scalar::Number(NotNan::new(4.0).expect("4 is not NaN"))),
|
NotNan::new(3.0).expect("3 is not NaN"),
|
||||||
|
)),
|
||||||
|
DifferentiableHidden::Scalar(Scalar::Number(
|
||||||
|
NotNan::new(4.0).expect("4 is not NaN"),
|
||||||
|
)),
|
||||||
]
|
]
|
||||||
.into(),
|
.into(),
|
||||||
);
|
);
|
||||||
@@ -352,9 +296,8 @@ mod tests {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let v = match mapped {
|
let v = match mapped {
|
||||||
Differentiable::Scalar(_) => panic!("Not a scalar"),
|
DifferentiableHidden::Scalar(_) => panic!("Not a scalar"),
|
||||||
Differentiable::Vector(v) => v
|
DifferentiableHidden::Vector(v) => v
|
||||||
.as_ref()
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|d| extract_scalar(d).clone())
|
.map(|d| extract_scalar(d).clone())
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
|
@@ -1,3 +1,8 @@
|
|||||||
|
#![allow(incomplete_features)]
|
||||||
|
#![feature(generic_const_exprs)]
|
||||||
|
|
||||||
pub mod auto_diff;
|
pub mod auto_diff;
|
||||||
pub mod expr_syntax_tree;
|
pub mod expr_syntax_tree;
|
||||||
|
pub mod scalar;
|
||||||
pub mod tensor;
|
pub mod tensor;
|
||||||
|
pub mod traits;
|
||||||
|
251
little_learner/src/scalar.rs
Normal file
251
little_learner/src/scalar.rs
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
use crate::traits::{Exp, One, Zero};
|
||||||
|
use core::hash::Hash;
|
||||||
|
use std::{
|
||||||
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
fmt::Display,
|
||||||
|
iter::Sum,
|
||||||
|
ops::{Add, AddAssign, Div, Mul, Neg, Sub},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||||
|
pub enum LinkData<A> {
|
||||||
|
Addition(Box<Scalar<A>>, Box<Scalar<A>>),
|
||||||
|
Neg(Box<Scalar<A>>),
|
||||||
|
Mul(Box<Scalar<A>>, Box<Scalar<A>>),
|
||||||
|
Exponent(Box<Scalar<A>>),
|
||||||
|
Log(Box<Scalar<A>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||||
|
pub enum Link<A> {
|
||||||
|
EndOfLink,
|
||||||
|
Link(LinkData<A>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Display for Link<A>
|
||||||
|
where
|
||||||
|
A: Display,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Link::EndOfLink => f.write_str("<end>"),
|
||||||
|
Link::Link(LinkData::Addition(left, right)) => {
|
||||||
|
f.write_fmt(format_args!("({} + {})", left.as_ref(), right.as_ref()))
|
||||||
|
}
|
||||||
|
Link::Link(LinkData::Neg(arg)) => f.write_fmt(format_args!("(-{})", arg.as_ref())),
|
||||||
|
Link::Link(LinkData::Mul(left, right)) => {
|
||||||
|
f.write_fmt(format_args!("({} * {})", left.as_ref(), right.as_ref()))
|
||||||
|
}
|
||||||
|
Link::Link(LinkData::Exponent(arg)) => {
|
||||||
|
f.write_fmt(format_args!("exp({})", arg.as_ref()))
|
||||||
|
}
|
||||||
|
Link::Link(LinkData::Log(arg)) => f.write_fmt(format_args!("log({})", arg.as_ref())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Link<A> {
|
||||||
|
pub fn invoke(self, d: &Scalar<A>, z: A, acc: &mut HashMap<Scalar<A>, A>)
|
||||||
|
where
|
||||||
|
A: Eq
|
||||||
|
+ Hash
|
||||||
|
+ AddAssign
|
||||||
|
+ Clone
|
||||||
|
+ Exp
|
||||||
|
+ Mul<Output = A>
|
||||||
|
+ Div<Output = A>
|
||||||
|
+ Neg<Output = A>
|
||||||
|
+ Zero
|
||||||
|
+ One,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Link::EndOfLink => match acc.entry(d.clone()) {
|
||||||
|
Entry::Occupied(mut o) => {
|
||||||
|
let entry = o.get_mut();
|
||||||
|
*entry += z;
|
||||||
|
}
|
||||||
|
Entry::Vacant(v) => {
|
||||||
|
v.insert(z);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Link::Link(data) => {
|
||||||
|
match data {
|
||||||
|
LinkData::Addition(left, right) => {
|
||||||
|
// The `z` here reflects the fact that dx/dx = 1, so it's 1 * z.
|
||||||
|
left.as_ref().clone_link().invoke(&left, z.clone(), acc);
|
||||||
|
right.as_ref().clone_link().invoke(&right, z, acc);
|
||||||
|
}
|
||||||
|
LinkData::Exponent(arg) => {
|
||||||
|
// d/dx (e^x) = exp x, so exp z * z.
|
||||||
|
arg.as_ref().clone_link().invoke(
|
||||||
|
&arg,
|
||||||
|
z * arg.clone_real_part().exp(),
|
||||||
|
acc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
LinkData::Mul(left, right) => {
|
||||||
|
// d/dx(f g) = f dg/dx + g df/dx
|
||||||
|
left.as_ref().clone_link().invoke(
|
||||||
|
&left,
|
||||||
|
right.clone_real_part() * z.clone(),
|
||||||
|
acc,
|
||||||
|
);
|
||||||
|
right
|
||||||
|
.as_ref()
|
||||||
|
.clone_link()
|
||||||
|
.invoke(&right, left.clone_real_part() * z, acc);
|
||||||
|
}
|
||||||
|
LinkData::Log(arg) => {
|
||||||
|
// d/dx(log y) = 1/y dy/dx
|
||||||
|
arg.as_ref().clone_link().invoke(
|
||||||
|
&arg,
|
||||||
|
A::one() / arg.clone_real_part() * z,
|
||||||
|
acc,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
LinkData::Neg(arg) => {
|
||||||
|
// d/dx(-y) = - dy/dx
|
||||||
|
arg.as_ref().clone_link().invoke(&arg, -z, acc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Hash, PartialEq, Eq)]
|
||||||
|
pub enum Scalar<A> {
|
||||||
|
Number(A),
|
||||||
|
// The value, and the link.
|
||||||
|
Dual(A, Link<A>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Zero for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Zero,
|
||||||
|
{
|
||||||
|
fn zero() -> Self {
|
||||||
|
Scalar::Number(A::zero())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Add for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Add<Output = A> + Clone,
|
||||||
|
{
|
||||||
|
type Output = Scalar<A>;
|
||||||
|
|
||||||
|
fn add(self, rhs: Self) -> Self::Output {
|
||||||
|
Scalar::Dual(
|
||||||
|
self.clone_real_part() + rhs.clone_real_part(),
|
||||||
|
Link::Link(LinkData::Addition(Box::new(self), Box::new(rhs))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Neg for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Neg<Output = A> + Clone,
|
||||||
|
{
|
||||||
|
type Output = Scalar<A>;
|
||||||
|
|
||||||
|
fn neg(self) -> Self::Output {
|
||||||
|
Scalar::Dual(
|
||||||
|
-self.clone_real_part(),
|
||||||
|
Link::Link(LinkData::Neg(Box::new(self))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Sub for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Add<Output = A> + Neg<Output = A> + Clone,
|
||||||
|
{
|
||||||
|
type Output = Scalar<A>;
|
||||||
|
|
||||||
|
fn sub(self, rhs: Self) -> Self::Output {
|
||||||
|
self + (-rhs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Mul for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Mul<Output = A> + Clone,
|
||||||
|
{
|
||||||
|
type Output = Scalar<A>;
|
||||||
|
|
||||||
|
fn mul(self, rhs: Self) -> Self::Output {
|
||||||
|
Scalar::Dual(
|
||||||
|
self.clone_real_part() * rhs.clone_real_part(),
|
||||||
|
Link::Link(LinkData::Mul(Box::new(self), Box::new(rhs))),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Sum for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Zero + Add<Output = A> + Clone,
|
||||||
|
{
|
||||||
|
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
|
||||||
|
let mut answer = Zero::zero();
|
||||||
|
for i in iter {
|
||||||
|
answer = answer + i;
|
||||||
|
}
|
||||||
|
answer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Scalar<A> {
|
||||||
|
pub fn real_part(&self) -> &A {
|
||||||
|
match self {
|
||||||
|
Scalar::Number(a) => a,
|
||||||
|
Scalar::Dual(a, _) => a,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clone_real_part(&self) -> A
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Scalar::Number(a) => (*a).clone(),
|
||||||
|
Scalar::Dual(a, _) => (*a).clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn link(self) -> Link<A> {
|
||||||
|
match self {
|
||||||
|
Scalar::Dual(_, link) => link,
|
||||||
|
Scalar::Number(_) => Link::EndOfLink,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clone_link(&self) -> Link<A>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Scalar::Dual(_, data) => data.clone(),
|
||||||
|
Scalar::Number(_) => Link::EndOfLink,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn truncate_dual(self) -> Scalar<A>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
{
|
||||||
|
Scalar::Dual(self.clone_real_part(), Link::EndOfLink)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Display for Scalar<A>
|
||||||
|
where
|
||||||
|
A: Display,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
Scalar::Number(n) => f.write_fmt(format_args!("{}", n)),
|
||||||
|
Scalar::Dual(n, link) => f.write_fmt(format_args!("{}, link: {}", n, link)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
43
little_learner/src/traits.rs
Normal file
43
little_learner/src/traits.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use ordered_float::NotNan;
|
||||||
|
|
||||||
|
pub trait Exp {
|
||||||
|
fn exp(self) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Exp for NotNan<f64> {
|
||||||
|
fn exp(self) -> Self {
|
||||||
|
NotNan::new(f64::exp(self.into_inner())).expect("expected a non-NaN")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait Zero {
|
||||||
|
fn zero() -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait One {
|
||||||
|
fn one() -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Zero for f64 {
|
||||||
|
fn zero() -> Self {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl One for f64 {
|
||||||
|
fn one() -> Self {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Zero for NotNan<f64> {
|
||||||
|
fn zero() -> Self {
|
||||||
|
NotNan::new(0.0).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl One for NotNan<f64> {
|
||||||
|
fn one() -> Self {
|
||||||
|
NotNan::new(1.0).unwrap()
|
||||||
|
}
|
||||||
|
}
|
@@ -1,93 +1,90 @@
|
|||||||
use little_learner::auto_diff::{Differentiable, Scalar};
|
#![allow(incomplete_features)]
|
||||||
use little_learner::tensor;
|
#![feature(generic_const_exprs)]
|
||||||
use little_learner::tensor::{extension2, Extensible2};
|
|
||||||
|
mod with_tensor;
|
||||||
|
|
||||||
|
use little_learner::auto_diff::{of_scalar, of_slice, to_scalar, Differentiable};
|
||||||
|
use little_learner::scalar::Scalar;
|
||||||
|
use little_learner::traits::{One, Zero};
|
||||||
use ordered_float::NotNan;
|
use ordered_float::NotNan;
|
||||||
|
|
||||||
use std::iter::Sum;
|
use std::iter::Sum;
|
||||||
use std::ops::{Mul, Sub};
|
use std::ops::{Add, Mul, Neg};
|
||||||
|
|
||||||
type Point<A, const N: usize> = [A; N];
|
use crate::with_tensor::{l2_loss, predict_line};
|
||||||
|
|
||||||
type Parameters<A, const N: usize, const M: usize> = [Point<A, N>; M];
|
fn dot_2<A, const RANK: usize>(
|
||||||
|
x: &Differentiable<A, RANK>,
|
||||||
fn dot_points<A: Mul, const N: usize>(x: &Point<A, N>, y: &Point<A, N>) -> A
|
y: &Differentiable<A, RANK>,
|
||||||
|
) -> Differentiable<A, RANK>
|
||||||
where
|
where
|
||||||
A: Sum<<A as Mul>::Output> + Copy + Default + Mul<Output = A> + Extensible2<A>,
|
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default,
|
||||||
{
|
{
|
||||||
extension2(x, y, |&x, &y| x * y).into_iter().sum()
|
Differentiable::map2(x, y, &|x, y| x.clone() * y.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dot<A, const N: usize, const M: usize>(x: &Point<A, N>, y: &Parameters<A, N, M>) -> Point<A, M>
|
fn squared_2<A, const RANK: usize>(x: &Differentiable<A, RANK>) -> Differentiable<A, RANK>
|
||||||
where
|
where
|
||||||
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + Extensible2<A>,
|
A: Mul<Output = A> + Copy + Default,
|
||||||
{
|
{
|
||||||
let mut result = [Default::default(); M];
|
Differentiable::map2(x, x, &|x, y| x.clone() * y.clone())
|
||||||
for (i, coord) in y.iter().map(|y| dot_points(x, y)).enumerate() {
|
|
||||||
result[i] = coord;
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sum<A, const N: usize>(x: &tensor!(A, N)) -> A
|
fn sum_2<A>(x: Differentiable<A, 1>) -> Scalar<A>
|
||||||
where
|
where
|
||||||
A: Sum<A> + Copy,
|
A: Sum<A> + Copy + Add<Output = A> + Zero,
|
||||||
{
|
{
|
||||||
A::sum(x.iter().cloned())
|
Differentiable::to_vector(x)
|
||||||
|
.into_iter()
|
||||||
|
.map(to_scalar)
|
||||||
|
.sum()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn squared<A, const N: usize>(x: &tensor!(A, N)) -> tensor!(A, N)
|
fn l2_norm_2<A>(prediction: &Differentiable<A, 1>, data: &Differentiable<A, 1>) -> Scalar<A>
|
||||||
where
|
where
|
||||||
A: Mul<Output = A> + Extensible2<A> + Copy + Default,
|
A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero + Neg,
|
||||||
{
|
{
|
||||||
extension2(x, x, |&a, &b| (a * b))
|
let diff = Differentiable::map2(prediction, data, &|x, y| x.clone() - y.clone());
|
||||||
|
sum_2(squared_2(&diff))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn l2_norm<A, const N: usize>(prediction: &tensor!(A, N), data: &tensor!(A, N)) -> A
|
pub fn l2_loss_2<A, F, Params>(
|
||||||
where
|
|
||||||
A: Sum<A> + Mul<Output = A> + Extensible2<A> + Copy + Default + Sub<Output = A>,
|
|
||||||
{
|
|
||||||
let diff = extension2(prediction, data, |&x, &y| x - y);
|
|
||||||
sum(&squared(&diff))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn l2_loss<A, F, Params, const N: usize>(
|
|
||||||
target: F,
|
target: F,
|
||||||
data_xs: &tensor!(A, N),
|
data_xs: Differentiable<A, 1>,
|
||||||
data_ys: &tensor!(A, N),
|
data_ys: Differentiable<A, 1>,
|
||||||
params: &Params,
|
params: Params,
|
||||||
) -> A
|
) -> Scalar<A>
|
||||||
where
|
where
|
||||||
F: Fn(&tensor!(A, N), &Params) -> tensor!(A, N),
|
F: Fn(Differentiable<A, 1>, Params) -> Differentiable<A, 1>,
|
||||||
A: Sum<A> + Mul<Output = A> + Extensible2<A> + Copy + Default + Sub<Output = A>,
|
A: Sum<A> + Mul<Output = A> + Copy + Default + Neg<Output = A> + Add<Output = A> + Zero,
|
||||||
{
|
{
|
||||||
let pred_ys = target(data_xs, params);
|
let pred_ys = target(data_xs, params);
|
||||||
l2_norm(&pred_ys, data_ys)
|
l2_norm_2(&pred_ys, &data_ys)
|
||||||
}
|
}
|
||||||
|
|
||||||
trait One {
|
fn predict_line_2<A>(xs: Differentiable<A, 1>, theta: Differentiable<A, 1>) -> Differentiable<A, 1>
|
||||||
const ONE: Self;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl One for f64 {
|
|
||||||
const ONE: f64 = 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn predict_line<A, const N: usize>(xs: &tensor!(A, N), theta: &tensor!(A, 2)) -> tensor!(A, N)
|
|
||||||
where
|
where
|
||||||
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + Extensible2<A> + One,
|
A: Mul<Output = A> + Add<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + One + Zero,
|
||||||
{
|
{
|
||||||
let mut result: tensor!(A, N) = [Default::default(); N];
|
let xs = Differentiable::to_vector(xs)
|
||||||
for (i, &x) in xs.iter().enumerate() {
|
.into_iter()
|
||||||
result[i] = dot(&[x, One::ONE], &[*theta])[0];
|
.map(|v| to_scalar(v));
|
||||||
|
let mut result = vec![];
|
||||||
|
for x in xs {
|
||||||
|
let left_arg = Differentiable::of_vector(vec![
|
||||||
|
of_scalar(x.clone()),
|
||||||
|
of_scalar(<Scalar<A> as One>::one()),
|
||||||
|
]);
|
||||||
|
let dotted = Differentiable::to_vector(dot_2(&left_arg, &theta));
|
||||||
|
result.push(dotted[0].clone());
|
||||||
}
|
}
|
||||||
result
|
Differentiable::of_vector(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn square<A>(x: &A) -> A
|
fn square<A>(x: &A) -> A
|
||||||
where
|
where
|
||||||
A: Mul<Output = A> + Clone + std::fmt::Display,
|
A: Mul<Output = A> + Clone,
|
||||||
{
|
{
|
||||||
println!("{}", x);
|
|
||||||
x.clone() * x.clone()
|
x.clone() * x.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,61 +97,16 @@ fn main() {
|
|||||||
);
|
);
|
||||||
println!("{:?}", loss);
|
println!("{:?}", loss);
|
||||||
|
|
||||||
let input_vec = Differentiable::Vector(Box::new([Differentiable::Scalar(Scalar::Number(
|
let loss = l2_loss_2(
|
||||||
NotNan::new(27.0).expect("not nan"),
|
predict_line_2,
|
||||||
))]));
|
of_slice(&[2.0, 1.0, 4.0, 3.0]),
|
||||||
|
of_slice(&[1.8, 1.2, 4.2, 3.3]),
|
||||||
|
of_slice(&[0.0099, 0.0]),
|
||||||
|
);
|
||||||
|
println!("{}", loss);
|
||||||
|
|
||||||
let grad = Differentiable::grad(|x| x.map(&|x| square(&x)), input_vec);
|
let input_vec = of_slice(&[NotNan::new(27.0).expect("not nan")]);
|
||||||
|
|
||||||
|
let grad = Differentiable::grad(|x| Differentiable::map(x, &|x| square(&x)), input_vec);
|
||||||
println!("{}", grad);
|
println!("{}", grad);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use little_learner::tensor::extension1;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_extension() {
|
|
||||||
let x: tensor!(u8, 1) = [2];
|
|
||||||
assert_eq!(extension1(&x, &7, |x, y| x + y), [9]);
|
|
||||||
let y: tensor!(u8, 1) = [7];
|
|
||||||
assert_eq!(extension2(&x, &y, |x, y| x + y), [9]);
|
|
||||||
|
|
||||||
let x: tensor!(u8, 3) = [5, 6, 7];
|
|
||||||
assert_eq!(extension1(&x, &2, |x, y| x + y), [7, 8, 9]);
|
|
||||||
let y: tensor!(u8, 3) = [2, 0, 1];
|
|
||||||
assert_eq!(extension2(&x, &y, |x, y| x + y), [7, 6, 8]);
|
|
||||||
|
|
||||||
let x: tensor!(u8, 2, 3) = [[4, 6, 7], [2, 0, 1]];
|
|
||||||
assert_eq!(extension1(&x, &2, |x, y| x + y), [[6, 8, 9], [4, 2, 3]]);
|
|
||||||
let y: tensor!(u8, 2, 3) = [[1, 2, 2], [6, 3, 1]];
|
|
||||||
assert_eq!(extension2(&x, &y, |x, y| x + y), [[5, 8, 9], [8, 3, 2]]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_l2_norm() {
|
|
||||||
assert_eq!(
|
|
||||||
l2_norm(&[4.0, -3.0, 0.0, -4.0, 3.0], &[0.0, 0.0, 0.0, 0.0, 0.0]),
|
|
||||||
50.0
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_l2_loss() {
|
|
||||||
let loss = l2_loss(
|
|
||||||
predict_line,
|
|
||||||
&[2.0, 1.0, 4.0, 3.0],
|
|
||||||
&[1.8, 1.2, 4.2, 3.3],
|
|
||||||
&[0.0, 0.0],
|
|
||||||
);
|
|
||||||
assert_eq!(loss, 33.21);
|
|
||||||
|
|
||||||
let loss = l2_loss(
|
|
||||||
predict_line,
|
|
||||||
&[2.0, 1.0, 4.0, 3.0],
|
|
||||||
&[1.8, 1.2, 4.2, 3.3],
|
|
||||||
&[0.0099, 0.0],
|
|
||||||
);
|
|
||||||
assert_eq!((100.0 * loss).round() / 100.0, 32.59);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
126
little_learner_app/src/with_tensor.rs
Normal file
126
little_learner_app/src/with_tensor.rs
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
use std::iter::Sum;
|
||||||
|
use std::ops::{Mul, Sub};
|
||||||
|
|
||||||
|
use little_learner::tensor;
|
||||||
|
use little_learner::tensor::{extension2, Extensible2};
|
||||||
|
use little_learner::traits::One;
|
||||||
|
|
||||||
|
type Point<A, const N: usize> = [A; N];
|
||||||
|
|
||||||
|
type Parameters<A, const N: usize, const M: usize> = [Point<A, N>; M];
|
||||||
|
|
||||||
|
fn dot_points<A: Mul, const N: usize>(x: &Point<A, N>, y: &Point<A, N>) -> A
|
||||||
|
where
|
||||||
|
A: Sum<<A as Mul>::Output> + Copy + Default + Mul<Output = A> + Extensible2<A>,
|
||||||
|
{
|
||||||
|
extension2(x, y, |&x, &y| x * y).into_iter().sum()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dot<A, const N: usize, const M: usize>(x: &Point<A, N>, y: &Parameters<A, N, M>) -> Point<A, M>
|
||||||
|
where
|
||||||
|
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + Extensible2<A>,
|
||||||
|
{
|
||||||
|
let mut result = [Default::default(); M];
|
||||||
|
for (i, coord) in y.iter().map(|y| dot_points(x, y)).enumerate() {
|
||||||
|
result[i] = coord;
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum<A, const N: usize>(x: &tensor!(A, N)) -> A
|
||||||
|
where
|
||||||
|
A: Sum<A> + Copy,
|
||||||
|
{
|
||||||
|
A::sum(x.iter().cloned())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn squared<A, const N: usize>(x: &tensor!(A, N)) -> tensor!(A, N)
|
||||||
|
where
|
||||||
|
A: Mul<Output = A> + Extensible2<A> + Copy + Default,
|
||||||
|
{
|
||||||
|
extension2(x, x, |&a, &b| (a * b))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn l2_norm<A, const N: usize>(prediction: &tensor!(A, N), data: &tensor!(A, N)) -> A
|
||||||
|
where
|
||||||
|
A: Sum<A> + Mul<Output = A> + Extensible2<A> + Copy + Default + Sub<Output = A>,
|
||||||
|
{
|
||||||
|
let diff = extension2(prediction, data, |&x, &y| x - y);
|
||||||
|
sum(&squared(&diff))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn l2_loss<A, F, Params, const N: usize>(
|
||||||
|
target: F,
|
||||||
|
data_xs: &tensor!(A, N),
|
||||||
|
data_ys: &tensor!(A, N),
|
||||||
|
params: &Params,
|
||||||
|
) -> A
|
||||||
|
where
|
||||||
|
F: Fn(&tensor!(A, N), &Params) -> tensor!(A, N),
|
||||||
|
A: Sum<A> + Mul<Output = A> + Extensible2<A> + Copy + Default + Sub<Output = A>,
|
||||||
|
{
|
||||||
|
let pred_ys = target(data_xs, params);
|
||||||
|
l2_norm(&pred_ys, data_ys)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict_line<A, const N: usize>(xs: &tensor!(A, N), theta: &tensor!(A, 2)) -> tensor!(A, N)
|
||||||
|
where
|
||||||
|
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + Extensible2<A> + One,
|
||||||
|
{
|
||||||
|
let mut result: tensor!(A, N) = [Default::default(); N];
|
||||||
|
for (i, &x) in xs.iter().enumerate() {
|
||||||
|
result[i] = dot(&[x, One::one()], &[*theta])[0];
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use little_learner::tensor::extension1;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_extension() {
|
||||||
|
let x: tensor!(u8, 1) = [2];
|
||||||
|
assert_eq!(extension1(&x, &7, |x, y| x + y), [9]);
|
||||||
|
let y: tensor!(u8, 1) = [7];
|
||||||
|
assert_eq!(extension2(&x, &y, |x, y| x + y), [9]);
|
||||||
|
|
||||||
|
let x: tensor!(u8, 3) = [5, 6, 7];
|
||||||
|
assert_eq!(extension1(&x, &2, |x, y| x + y), [7, 8, 9]);
|
||||||
|
let y: tensor!(u8, 3) = [2, 0, 1];
|
||||||
|
assert_eq!(extension2(&x, &y, |x, y| x + y), [7, 6, 8]);
|
||||||
|
|
||||||
|
let x: tensor!(u8, 2, 3) = [[4, 6, 7], [2, 0, 1]];
|
||||||
|
assert_eq!(extension1(&x, &2, |x, y| x + y), [[6, 8, 9], [4, 2, 3]]);
|
||||||
|
let y: tensor!(u8, 2, 3) = [[1, 2, 2], [6, 3, 1]];
|
||||||
|
assert_eq!(extension2(&x, &y, |x, y| x + y), [[5, 8, 9], [8, 3, 2]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_l2_norm() {
|
||||||
|
assert_eq!(
|
||||||
|
l2_norm(&[4.0, -3.0, 0.0, -4.0, 3.0], &[0.0, 0.0, 0.0, 0.0, 0.0]),
|
||||||
|
50.0
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_l2_loss() {
|
||||||
|
let loss = l2_loss(
|
||||||
|
predict_line,
|
||||||
|
&[2.0, 1.0, 4.0, 3.0],
|
||||||
|
&[1.8, 1.2, 4.2, 3.3],
|
||||||
|
&[0.0, 0.0],
|
||||||
|
);
|
||||||
|
assert_eq!(loss, 33.21);
|
||||||
|
|
||||||
|
let loss = l2_loss(
|
||||||
|
predict_line,
|
||||||
|
&[2.0, 1.0, 4.0, 3.0],
|
||||||
|
&[1.8, 1.2, 4.2, 3.3],
|
||||||
|
&[0.0099, 0.0],
|
||||||
|
);
|
||||||
|
assert_eq!((100.0 * loss).round() / 100.0, 32.59);
|
||||||
|
}
|
||||||
|
}
|
1
rust-toolchain
Normal file
1
rust-toolchain
Normal file
@@ -0,0 +1 @@
|
|||||||
|
nightly
|
Reference in New Issue
Block a user