commit c3bfeb0762 (parent 0a199b0065)
Patrick Stevens, 2023-03-23 19:17:53 +00:00, committed by GitHub

use std::iter::Sum;
use std::ops::{Mul, Sub};

type Point<A, const N: usize> = [A; N];
type Parameters<A, const N: usize, const M: usize> = [Point<A, N>; M];

/// Builds the type of an N-dimensional tensor, outermost dimension first.
#[macro_export]
macro_rules! tensor {
    ($x:ty, $i:expr) => { [$x; $i] };
    ($x:ty, $i:expr, $($is:expr),+) => { [tensor!($x, $($is),+); $i] };
}
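
// A sketch of what the macro produces (matching test_tensor_type below):
// tensor!(u8, 3)     expands to [u8; 3]
// tensor!(f64, 2, 3) expands to [[f64; 3]; 2]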
/// Pointwise application of a binary operation against a single scalar.
pub trait Extensible1<A> {
    fn apply<F>(&self, other: &A, op: &F) -> Self
    where
        F: Fn(&A, &A) -> A;
}

/// Pointwise application of a binary operation against a same-shaped tensor.
pub trait Extensible2<A> {
    fn apply<F>(&self, other: &Self, op: &F) -> Self
    where
        F: Fn(&A, &A) -> A;
}
impl<A, T, const N: usize> Extensible1<A> for [T; N]
where
    T: Extensible1<A> + Copy + Default,
{
    fn apply<F>(&self, other: &A, op: &F) -> Self
    where
        F: Fn(&A, &A) -> A,
    {
        let mut result = [Default::default(); N];
        for (i, coord) in self.iter().enumerate() {
            result[i] = T::apply(coord, other, op);
        }
        result
    }
}

impl<A, T, const N: usize> Extensible2<A> for [T; N]
where
    T: Extensible2<A> + Copy + Default,
{
    fn apply<F>(&self, other: &Self, op: &F) -> Self
    where
        F: Fn(&A, &A) -> A,
    {
        let mut result = [Default::default(); N];
        for (i, coord) in self.iter().enumerate() {
            result[i] = T::apply(coord, &other[i], op);
        }
        result
    }
}
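
// How the recursion grounds out, informally: for a tensor!(u8, 2, 3), the array
// impls above strip one dimension per call, and the scalar impls generated by
// the macros below finally apply `op` to the individual u8 leaves.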
#[macro_export]
macro_rules! extensible1 {
    ($x:ty) => {
        impl Extensible1<$x> for $x {
            fn apply<F>(&self, other: &$x, op: &F) -> Self
            where
                F: Fn(&Self, &Self) -> Self,
            {
                op(self, other)
            }
        }
    };
}

#[macro_export]
macro_rules! extensible2 {
    ($x:ty) => {
        impl Extensible2<$x> for $x {
            fn apply<F>(&self, other: &Self, op: &F) -> Self
            where
                F: Fn(&Self, &Self) -> Self,
            {
                op(self, other)
            }
        }
    };
}

extensible1!(u8);
extensible1!(f64);
extensible2!(u8);
extensible2!(f64);
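
// Only u8 and f64 are enabled here. Other scalars would presumably work the
// same way, e.g. `extensible1!(f32); extensible2!(f32);` (an untested sketch).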
pub fn extension1<T, A, F>(t1: &T, t2: &A, op: F) -> T
where
    T: Extensible1<A>,
    F: Fn(&A, &A) -> A,
{
    t1.apply::<F>(t2, &op)
}

pub fn extension2<T, A, F>(t1: &T, t2: &T, op: F) -> T
where
    T: Extensible2<A>,
    F: Fn(&A, &A) -> A,
{
    t1.apply::<F>(t2, &op)
}
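
// Example usage, mirroring test_extension below:
// extension1(&[5u8, 6, 7], &2, |x, y| x + y)         == [7, 8, 9]
// extension2(&[5u8, 6, 7], &[2, 0, 1], |x, y| x + y) == [7, 6, 8]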
fn dot_points<A: Mul, const N: usize>(x: &Point<A, N>, y: &Point<A, N>) -> A
where
    A: Sum<<A as Mul>::Output> + Copy + Default + Mul<Output = A> + Extensible2<A>,
{
    extension2(x, y, |&x, &y| x * y).into_iter().sum()
}
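
// A worked example (a sketch, not from the original source):
// dot_points(&[1.0, 2.0], &[3.0, 4.0]) == 1.0 * 3.0 + 2.0 * 4.0 == 11.0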
fn dot<A, const N: usize, const M: usize>(x: &Point<A, N>, y: &Parameters<A, N, M>) -> Point<A, M>
where
    A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + Extensible2<A>,
{
    let mut result = [Default::default(); M];
    for (i, coord) in y.iter().map(|y| dot_points(x, y)).enumerate() {
        result[i] = coord;
    }
    result
}
fn sum<A, const N: usize>(x: &tensor!(A, N)) -> A
where
    A: Sum<A> + Copy,
{
    A::sum(x.iter().cloned())
}
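
// e.g. sum(&[1.0, 2.0, 3.0]) == 6.0 (illustrative, not in the original)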
fn squared<A, const N: usize>(x: &tensor!(A, N)) -> tensor!(A, N)
where
    A: Mul<Output = A> + Extensible2<A> + Copy + Default,
{
    extension2(x, x, |&a, &b| a * b)
}
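
// e.g. squared(&[1.0, -2.0, 3.0]) == [1.0, 4.0, 9.0] (illustrative sketch)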
fn l2_norm<A, const N: usize>(prediction: &tensor!(A, N), data: &tensor!(A, N)) -> A
where
    A: Sum<A> + Mul<Output = A> + Extensible2<A> + Copy + Default + Sub<Output = A>,
{
    let diff = extension2(prediction, data, |&x, &y| x - y);
    sum(&squared(&diff))
}
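
// A hand-worked instance of the norm (compare test_l2_norm below):
// l2_norm(&[4.0, -3.0], &[0.0, 0.0]) == 16.0 + 9.0 == 25.0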
pub fn l2_loss<A, F, Params, const N: usize>(
    target: F,
    data_xs: &tensor!(A, N),
    data_ys: &tensor!(A, N),
    params: &Params,
) -> A
where
    F: Fn(&tensor!(A, N), &Params) -> tensor!(A, N),
    A: Sum<A> + Mul<Output = A> + Extensible2<A> + Copy + Default + Sub<Output = A>,
{
    let pred_ys = target(data_xs, params);
    l2_norm(&pred_ys, data_ys)
}
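
// l2_loss runs `target` over the xs and measures the squared error against the
// ys; with predict_line and the data used in main, params [0.0, 0.0] give
// 33.21 (see test_l2_loss below).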
trait One {
    const ONE: Self;
}

impl One for f64 {
    const ONE: f64 = 1.0;
}
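
// Only f64 gets a ONE here; other numeric types would presumably follow the
// same pattern, e.g. `impl One for f32 { const ONE: f32 = 1.0; }` (a sketch).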
fn predict_line<A, const N: usize>(xs: &tensor!(A, N), theta: &tensor!(A, 2)) -> tensor!(A, N)
where
    A: Mul<Output = A> + Sum<<A as Mul>::Output> + Copy + Default + Extensible2<A> + One,
{
    let mut result: tensor!(A, N) = [Default::default(); N];
    for (i, &x) in xs.iter().enumerate() {
        result[i] = dot(&[x, One::ONE], &[*theta])[0];
    }
    result
}
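
// theta is [slope, intercept]: each x maps to dot([x, 1], theta), so for
// example predict_line(&[2.0], &[3.0, 1.0]) == [3.0 * 2.0 + 1.0] == [7.0].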
fn main() {
    let loss = l2_loss(
        predict_line,
        &[2.0, 1.0, 4.0, 3.0],
        &[1.8, 1.2, 4.2, 3.3],
        &[0.0099, 0.0],
    );
    println!("{:?}", loss);
}
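
// Running main prints the loss for slope 0.0099 and intercept 0.0; per
// test_l2_loss below, it rounds to 32.59.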
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_type() {
        let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]];
    }
    #[test]
    fn test_extension() {
        let x: tensor!(u8, 1) = [2];
        assert_eq!(extension1(&x, &7, |x, y| x + y), [9]);
        let y: tensor!(u8, 1) = [7];
        assert_eq!(extension2(&x, &y, |x, y| x + y), [9]);

        let x: tensor!(u8, 3) = [5, 6, 7];
        assert_eq!(extension1(&x, &2, |x, y| x + y), [7, 8, 9]);
        let y: tensor!(u8, 3) = [2, 0, 1];
        assert_eq!(extension2(&x, &y, |x, y| x + y), [7, 6, 8]);

        let x: tensor!(u8, 2, 3) = [[4, 6, 7], [2, 0, 1]];
        assert_eq!(extension1(&x, &2, |x, y| x + y), [[6, 8, 9], [4, 2, 3]]);
        let y: tensor!(u8, 2, 3) = [[1, 2, 2], [6, 3, 1]];
        assert_eq!(extension2(&x, &y, |x, y| x + y), [[5, 8, 9], [8, 3, 2]]);
    }
    #[test]
    fn test_l2_norm() {
        assert_eq!(
            l2_norm(&[4.0, -3.0, 0.0, -4.0, 3.0], &[0.0, 0.0, 0.0, 0.0, 0.0]),
            50.0
        )
    }
    #[test]
    fn test_l2_loss() {
        let loss = l2_loss(
            predict_line,
            &[2.0, 1.0, 4.0, 3.0],
            &[1.8, 1.2, 4.2, 3.3],
            &[0.0, 0.0],
        );
        assert_eq!(loss, 33.21);

        let loss = l2_loss(
            predict_line,
            &[2.0, 1.0, 4.0, 3.0],
            &[1.8, 1.2, 4.2, 3.3],
            &[0.0099, 0.0],
        );
        assert_eq!((100.0 * loss).round() / 100.0, 32.59);
    }
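
    // An extra sanity check beyond the original tests (a sketch; values worked
    // by hand: 1 + 4 + 9 = 14):
    #[test]
    fn test_sum_squared() {
        assert_eq!(sum(&squared(&[1.0, -2.0, 3.0])), 14.0);
    }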
}