Describe the network (#31)

This commit is contained in:
Patrick Stevens
2023-06-17 23:03:32 +01:00
committed by GitHub
parent bdb5d8e192
commit f873e5ca3d
5 changed files with 33 additions and 15 deletions

View File

@@ -129,7 +129,7 @@ jobs:
},
{
"name": "Run Clippy",
"run": "nix develop --command cargo -- clippy -- -D warnings"
"run": "nix develop --command cargo -- clippy -- -D warnings -W clippy::must_use_candidate"
}
]
}

View File

@@ -250,12 +250,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
}
(
DifferentiableContents::Vector(v1, rank1),
DifferentiableContents::Vector(v2, _rank2),
DifferentiableContents::Vector(v2, rank2),
) => {
assert_eq!(
v1.len(),
v2.len(),
"Must map two vectors of the same length, got {rank1} and {_rank2}"
"Must map two vectors of the same length, got {rank1} and {rank2}"
);
assert_ne!(
v1.len(),

View File

@@ -1,4 +1,4 @@
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
use crate::ext::relu;
use crate::traits::NumLike;
@@ -7,30 +7,39 @@ pub struct Block<F, const N: usize> {
ranks: [usize; N],
}
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
/// Does the second argument first, so compose(b1, b2) performs b2 on its input, and then b1.
pub fn compose<'a, 'c, 'd, A, T, B, C, F, G, const N: usize, const M: usize>(
b1: Block<F, N>,
b2: Block<G, M>,
j: usize,
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
) -> Block<impl FnOnce(&'a A, &'d [T]) -> C, { N + M }>
where
F: FnOnce(A, &'a [T]) -> B,
G: FnOnce(B, &'a [T]) -> C,
T: 'a,
F: FnOnce(&'a A, &'d [T]) -> B,
G: for<'b> FnOnce(&'b B, &'d [T]) -> C,
A: 'a,
T: 'd,
{
let mut ranks = [0usize; N + M];
ranks.copy_from_slice(&b1.ranks[..N]);
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
ranks[..N].copy_from_slice(&b1.ranks);
ranks[N..(M + N)].copy_from_slice(&b2.ranks);
Block {
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
f: move |t, theta| {
let intermediate = (b1.f)(t, theta);
(b2.f)(&intermediate, &theta[j..])
},
ranks,
}
}
pub fn dense<'a, 'b, A, Tag>(
#[must_use]
pub fn dense<'b, A, Tag>(
input_len: usize,
neuron_count: usize,
) -> Block<
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
impl for<'a> FnOnce(
&'a RankedDifferentiableTagged<A, Tag, 1>,
&'b [Differentiable<A>],
) -> RankedDifferentiable<A, 1>,
2,
>
where
@@ -38,12 +47,16 @@ where
A: NumLike + PartialOrd + Default,
{
Block {
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
f: for<'a> |t: &'a RankedDifferentiableTagged<A, Tag, 1>,
theta: &'b [Differentiable<A>]|
-> RankedDifferentiable<A, 1> {
relu(
t,
&(theta[0].clone().attach_rank().unwrap()),
&(theta[1].clone().attach_rank().unwrap()),
)
.attach_rank()
.unwrap()
},
ranks: [input_len, neuron_count],
}

View File

@@ -1,6 +1,7 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![feature(array_methods)]
#![feature(closure_lifetime_binder)]
pub mod auto_diff;
pub mod block;

View File

@@ -3,6 +3,8 @@
use crate::rms_example::rms_example;
use little_learner::auto_diff::RankedDifferentiable;
use little_learner::block;
use ordered_float::NotNan;
mod iris;
mod rms_example;
@@ -20,4 +22,6 @@ fn main() {
}
let _xs = RankedDifferentiable::of_vector(xs);
let _ys = RankedDifferentiable::of_vector(ys);
let _network = block::compose(block::dense::<NotNan<f64>, ()>(6, 3), block::dense(4, 6), 2);
}