Describe the network (#31)
This commit is contained in:
2
.github/workflows/rust.yml
vendored
2
.github/workflows/rust.yml
vendored
@@ -129,7 +129,7 @@ jobs:
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Run Clippy",
|
"name": "Run Clippy",
|
||||||
"run": "nix develop --command cargo -- clippy -- -D warnings"
|
"run": "nix develop --command cargo -- clippy -- -D warnings -W clippy::must_use_candidate"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -250,12 +250,12 @@ impl<A, Tag> DifferentiableContents<A, Tag> {
|
|||||||
}
|
}
|
||||||
(
|
(
|
||||||
DifferentiableContents::Vector(v1, rank1),
|
DifferentiableContents::Vector(v1, rank1),
|
||||||
DifferentiableContents::Vector(v2, _rank2),
|
DifferentiableContents::Vector(v2, rank2),
|
||||||
) => {
|
) => {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
v1.len(),
|
v1.len(),
|
||||||
v2.len(),
|
v2.len(),
|
||||||
"Must map two vectors of the same length, got {rank1} and {_rank2}"
|
"Must map two vectors of the same length, got {rank1} and {rank2}"
|
||||||
);
|
);
|
||||||
assert_ne!(
|
assert_ne!(
|
||||||
v1.len(),
|
v1.len(),
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use crate::auto_diff::{Differentiable, RankedDifferentiableTagged};
|
use crate::auto_diff::{Differentiable, RankedDifferentiable, RankedDifferentiableTagged};
|
||||||
use crate::ext::relu;
|
use crate::ext::relu;
|
||||||
use crate::traits::NumLike;
|
use crate::traits::NumLike;
|
||||||
|
|
||||||
@@ -7,30 +7,39 @@ pub struct Block<F, const N: usize> {
|
|||||||
ranks: [usize; N],
|
ranks: [usize; N],
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn compose<'a, A, T, B, C, F, G, const N: usize, const M: usize>(
|
/// Does the second argument first, so compose(b1, b2) performs b2 on its input, and then b1.
|
||||||
|
pub fn compose<'a, 'c, 'd, A, T, B, C, F, G, const N: usize, const M: usize>(
|
||||||
b1: Block<F, N>,
|
b1: Block<F, N>,
|
||||||
b2: Block<G, M>,
|
b2: Block<G, M>,
|
||||||
j: usize,
|
j: usize,
|
||||||
) -> Block<impl FnOnce(A, &'a [T]) -> C, { N + M }>
|
) -> Block<impl FnOnce(&'a A, &'d [T]) -> C, { N + M }>
|
||||||
where
|
where
|
||||||
F: FnOnce(A, &'a [T]) -> B,
|
F: FnOnce(&'a A, &'d [T]) -> B,
|
||||||
G: FnOnce(B, &'a [T]) -> C,
|
G: for<'b> FnOnce(&'b B, &'d [T]) -> C,
|
||||||
T: 'a,
|
A: 'a,
|
||||||
|
T: 'd,
|
||||||
{
|
{
|
||||||
let mut ranks = [0usize; N + M];
|
let mut ranks = [0usize; N + M];
|
||||||
ranks.copy_from_slice(&b1.ranks[..N]);
|
ranks[..N].copy_from_slice(&b1.ranks);
|
||||||
ranks[N..(M + N)].copy_from_slice(&b2.ranks[..M]);
|
ranks[N..(M + N)].copy_from_slice(&b2.ranks);
|
||||||
Block {
|
Block {
|
||||||
f: move |t, theta| (b2.f)((b1.f)(t, theta), &theta[j..]),
|
f: move |t, theta| {
|
||||||
|
let intermediate = (b1.f)(t, theta);
|
||||||
|
(b2.f)(&intermediate, &theta[j..])
|
||||||
|
},
|
||||||
ranks,
|
ranks,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dense<'a, 'b, A, Tag>(
|
#[must_use]
|
||||||
|
pub fn dense<'b, A, Tag>(
|
||||||
input_len: usize,
|
input_len: usize,
|
||||||
neuron_count: usize,
|
neuron_count: usize,
|
||||||
) -> Block<
|
) -> Block<
|
||||||
impl FnOnce(&'a RankedDifferentiableTagged<A, Tag, 1>, &'b [Differentiable<A>]) -> Differentiable<A>,
|
impl for<'a> FnOnce(
|
||||||
|
&'a RankedDifferentiableTagged<A, Tag, 1>,
|
||||||
|
&'b [Differentiable<A>],
|
||||||
|
) -> RankedDifferentiable<A, 1>,
|
||||||
2,
|
2,
|
||||||
>
|
>
|
||||||
where
|
where
|
||||||
@@ -38,12 +47,16 @@ where
|
|||||||
A: NumLike + PartialOrd + Default,
|
A: NumLike + PartialOrd + Default,
|
||||||
{
|
{
|
||||||
Block {
|
Block {
|
||||||
f: |t, theta: &'b [Differentiable<A>]| -> Differentiable<A> {
|
f: for<'a> |t: &'a RankedDifferentiableTagged<A, Tag, 1>,
|
||||||
|
theta: &'b [Differentiable<A>]|
|
||||||
|
-> RankedDifferentiable<A, 1> {
|
||||||
relu(
|
relu(
|
||||||
t,
|
t,
|
||||||
&(theta[0].clone().attach_rank().unwrap()),
|
&(theta[0].clone().attach_rank().unwrap()),
|
||||||
&(theta[1].clone().attach_rank().unwrap()),
|
&(theta[1].clone().attach_rank().unwrap()),
|
||||||
)
|
)
|
||||||
|
.attach_rank()
|
||||||
|
.unwrap()
|
||||||
},
|
},
|
||||||
ranks: [input_len, neuron_count],
|
ranks: [input_len, neuron_count],
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#![allow(incomplete_features)]
|
#![allow(incomplete_features)]
|
||||||
#![feature(generic_const_exprs)]
|
#![feature(generic_const_exprs)]
|
||||||
#![feature(array_methods)]
|
#![feature(array_methods)]
|
||||||
|
#![feature(closure_lifetime_binder)]
|
||||||
|
|
||||||
pub mod auto_diff;
|
pub mod auto_diff;
|
||||||
pub mod block;
|
pub mod block;
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
|
|
||||||
use crate::rms_example::rms_example;
|
use crate::rms_example::rms_example;
|
||||||
use little_learner::auto_diff::RankedDifferentiable;
|
use little_learner::auto_diff::RankedDifferentiable;
|
||||||
|
use little_learner::block;
|
||||||
|
use ordered_float::NotNan;
|
||||||
|
|
||||||
mod iris;
|
mod iris;
|
||||||
mod rms_example;
|
mod rms_example;
|
||||||
@@ -20,4 +22,6 @@ fn main() {
|
|||||||
}
|
}
|
||||||
let _xs = RankedDifferentiable::of_vector(xs);
|
let _xs = RankedDifferentiable::of_vector(xs);
|
||||||
let _ys = RankedDifferentiable::of_vector(ys);
|
let _ys = RankedDifferentiable::of_vector(ys);
|
||||||
|
|
||||||
|
let _network = block::compose(block::dense::<NotNan<f64>, ()>(6, 3), block::dense(4, 6), 2);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user