Describe the network (#31)
This commit is contained in:
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -129,7 +129,7 @@ jobs:
|
||||
},
|
||||
{
|
||||
"name": "Run Clippy",
|
||||
"run": "nix develop --command cargo -- clippy -- -D warnings"
|
||||
"run": "nix develop --command cargo -- clippy -- -D warnings -W clippy::must_use_candidate"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@@ -250,12 +250,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
|
||||
}
|
||||
(
|
||||
DifferentiableContents::Vector(v1, rank1),
|
||||
DifferentiableContents::Vector(v2, _rank2),
|
||||
DifferentiableContents::Vector(v2, rank2),
|
||||
) => {
|
||||
assert_eq!(
|
||||
v1.len(),
|
||||
v2.len(),
|
||||
"Must map two vectors of the same length, got {rank1} and {_rank2}"
|
||||
"Must map two vectors of the same length, got {rank1} and {rank2}"
|
||||
);
|
||||
assert_ne!(
|
||||
v1.len(),
|
||||
|
@@ -1,4 +1,4 @@
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||
use crate::ext::relu;
|
||||
use crate::traits::NumLike;
|
||||
|
||||
@@ -7,30 +7,39 @@ pub struct Block<F, const N: usize> {
|
||||
ranks: [usize; N],
|
||||
}
|
||||
|
||||
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
|
||||
/// Does the second argument first, so compose(b1, b2) performs b2 on its input, and then b1.
|
||||
pub fn compose<'a, 'c, 'd, A, T, B, C, F, G, const N: usize, const M: usize>(
|
||||
b1: Block<F, N>,
|
||||
b2: Block<G, M>,
|
||||
j: usize,
|
||||
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
|
||||
) -> Block<impl FnOnce(&'a A, &'d [T]) -> C, { N + M }>
|
||||
where
|
||||
F: FnOnce(A, &'a [T]) -> B,
|
||||
G: FnOnce(B, &'a [T]) -> C,
|
||||
T: 'a,
|
||||
F: FnOnce(&'a A, &'d [T]) -> B,
|
||||
G: for<'b> FnOnce(&'b B, &'d [T]) -> C,
|
||||
A: 'a,
|
||||
T: 'd,
|
||||
{
|
||||
let mut ranks = [0usize; N + M];
|
||||
ranks.copy_from_slice(&b1.ranks[..N]);
|
||||
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
|
||||
ranks[..N].copy_from_slice(&b1.ranks);
|
||||
ranks[N..(M + N)].copy_from_slice(&b2.ranks);
|
||||
Block {
|
||||
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
|
||||
f: move |t, theta| {
|
||||
let intermediate = (b1.f)(t, theta);
|
||||
(b2.f)(&intermediate, &theta[j..])
|
||||
},
|
||||
ranks,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dense<'a, 'b, A, Tag>(
|
||||
#[must_use]
|
||||
pub fn dense<'b, A, Tag>(
|
||||
input_len: usize,
|
||||
neuron_count: usize,
|
||||
) -> Block<
|
||||
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
|
||||
impl for<'a> FnOnce(
|
||||
&'a RankedDifferentiableTagged<A, Tag, 1>,
|
||||
&'b [Differentiable<A>],
|
||||
) -> RankedDifferentiable<A, 1>,
|
||||
2,
|
||||
>
|
||||
where
|
||||
@@ -38,12 +47,16 @@ where
|
||||
A: NumLike + PartialOrd + Default,
|
||||
{
|
||||
Block {
|
||||
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
|
||||
f: for<'a> |t: &'a RankedDifferentiableTagged<A, Tag, 1>,
|
||||
theta: &'b [Differentiable<A>]|
|
||||
-> RankedDifferentiable<A, 1> {
|
||||
relu(
|
||||
t,
|
||||
&(theta[0].clone().attach_rank().unwrap()),
|
||||
&(theta[1].clone().attach_rank().unwrap()),
|
||||
)
|
||||
.attach_rank()
|
||||
.unwrap()
|
||||
},
|
||||
ranks: [input_len, neuron_count],
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
#![feature(array_methods)]
|
||||
#![feature(closure_lifetime_binder)]
|
||||
|
||||
pub mod auto_diff;
|
||||
pub mod block;
|
||||
|
@@ -3,6 +3,8 @@
|
||||
|
||||
use crate::rms_example::rms_example;
|
||||
use little_learner::auto_diff::RankedDifferentiable;
|
||||
use little_learner::block;
|
||||
use ordered_float::NotNan;
|
||||
|
||||
mod iris;
|
||||
mod rms_example;
|
||||
@@ -20,4 +22,6 @@ fn main() {
|
||||
}
|
||||
let _xs = RankedDifferentiable::of_vector(xs);
|
||||
let _ys = RankedDifferentiable::of_vector(ys);
|
||||
|
||||
let _network = block::compose(block::dense::<NotNan<f64>, ()>(6, 3), block::dense(4, 6), 2);
|
||||
}
|
||||
|
Reference in New Issue
Block a user