Describe the network (#31)

This commit is contained in:
Patrick Stevens
2023-06-17 23:03:32 +01:00
committed by GitHub
parent bdb5d8e192
commit f873e5ca3d
5 changed files with 33 additions and 15 deletions

View File

@@ -129,7 +129,7 @@ jobs:
}, },
{ {
"name": "Run Clippy", "name": "Run Clippy",
"run": "nix develop --command cargo -- clippy -- -D warnings" "run": "nix develop --command cargo -- clippy -- -D warnings -W clippy::must_use_candidate"
} }
] ]
} }

View File

@@ -250,12 +250,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
} }
( (
DifferentiableContents::Vector(v1, rank1), DifferentiableContents::Vector(v1, rank1),
DifferentiableContents::Vector(v2, _rank2), DifferentiableContents::Vector(v2, rank2),
) => { ) => {
assert_eq!( assert_eq!(
v1.len(), v1.len(),
v2.len(), v2.len(),
"Must map two vectors of the same length, got {rank1} and {_rank2}" "Must map two vectors of the same length, got {rank1} and {rank2}"
); );
assert_ne!( assert_ne!(
v1.len(), v1.len(),

View File

@@ -1,4 +1,4 @@
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged}; use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
use crate::ext::relu; use crate::ext::relu;
use crate::traits::NumLike; use crate::traits::NumLike;
@@ -7,30 +7,39 @@ pub struct Block<F, const N: usize> {
ranks: [usize; N], ranks: [usize; N],
} }
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>( /// Does the second argument first, so compose(b1, b2) performs b2 on its input, and then b1.
pub fn compose<'a, 'c, 'd, A, T, B, C, F, G, const N: usize, const M: usize>(
b1: Block<F, N>, b1: Block<F, N>,
b2: Block<G, M>, b2: Block<G, M>,
j: usize, j: usize,
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }> ) -> Block<impl FnOnce(&'a A, &'d [T]) -> C, { N + M }>
where where
F: FnOnce(A, &'a [T]) -> B, F: FnOnce(&'a A, &'d [T]) -> B,
G: FnOnce(B, &'a [T]) -> C, G: for<'b> FnOnce(&'b B, &'d [T]) -> C,
T: 'a, A: 'a,
T: 'd,
{ {
let mut ranks = [0usize; N + M]; let mut ranks = [0usize; N + M];
ranks.copy_from_slice(&b1.ranks[..N]); ranks[..N].copy_from_slice(&b1.ranks);
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]); ranks[N..(M + N)].copy_from_slice(&b2.ranks);
Block { Block {
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]), f: move |t, theta| {
let intermediate = (b1.f)(t, theta);
(b2.f)(&intermediate, &theta[j..])
},
ranks, ranks,
} }
} }
pub fn dense<'a, 'b, A, Tag>( #[must_use]
pub fn dense<'b, A, Tag>(
input_len: usize, input_len: usize,
neuron_count: usize, neuron_count: usize,
) -> Block< ) -> Block<
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>, impl for<'a> FnOnce(
&'a RankedDifferentiableTagged<A, Tag, 1>,
&'b [Differentiable<A>],
) -> RankedDifferentiable<A, 1>,
2, 2,
> >
where where
@@ -38,12 +47,16 @@ where
A: NumLike + PartialOrd + Default, A: NumLike + PartialOrd + Default,
{ {
Block { Block {
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> { f: for<'a> |t: &'a RankedDifferentiableTagged<A, Tag, 1>,
theta: &'b [Differentiable<A>]|
-> RankedDifferentiable<A, 1> {
relu( relu(
t, t,
&(theta[0].clone().attach_rank().unwrap()), &(theta[0].clone().attach_rank().unwrap()),
&(theta[1].clone().attach_rank().unwrap()), &(theta[1].clone().attach_rank().unwrap()),
) )
.attach_rank()
.unwrap()
}, },
ranks: [input_len, neuron_count], ranks: [input_len, neuron_count],
} }

View File

@@ -1,6 +1,7 @@
#![allow(incomplete_features)] #![allow(incomplete_features)]
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
#![feature(array_methods)] #![feature(array_methods)]
#![feature(closure_lifetime_binder)]
pub mod auto_diff; pub mod auto_diff;
pub mod block; pub mod block;

View File

@@ -3,6 +3,8 @@
use crate::rms_example::rms_example; use crate::rms_example::rms_example;
use little_learner::auto_diff::RankedDifferentiable; use little_learner::auto_diff::RankedDifferentiable;
use little_learner::block;
use ordered_float::NotNan;
mod iris; mod iris;
mod rms_example; mod rms_example;
@@ -20,4 +22,6 @@ fn main() {
} }
let _xs = RankedDifferentiable::of_vector(xs); let _xs = RankedDifferentiable::of_vector(xs);
let _ys = RankedDifferentiable::of_vector(ys); let _ys = RankedDifferentiable::of_vector(ys);
let _network = block::compose(block::dense::<NotNan<f64>, ()>(6, 3), block::dense(4, 6), 2);
} }