Velocity descent (#16)
This commit is contained in:
@@ -7,13 +7,14 @@ use std::{
|
||||
ops::{AddAssign, Div, Mul, Neg},
|
||||
};
|
||||
|
||||
impl<A> Zero for Differentiable<A>
|
||||
impl<A, Tag> Zero for DifferentiableTagged<A, Tag>
|
||||
where
|
||||
A: Zero,
|
||||
Tag: Zero,
|
||||
{
|
||||
fn zero() -> Differentiable<A> {
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None)),
|
||||
fn zero() -> DifferentiableTagged<A, Tag> {
|
||||
DifferentiableTagged {
|
||||
contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None), Tag::zero()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -27,59 +28,62 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> One for Differentiable<A>
|
||||
impl<A, Tag> One for DifferentiableTagged<A, Tag>
|
||||
where
|
||||
A: One,
|
||||
Tag: Zero,
|
||||
{
|
||||
fn one() -> Differentiable<A> {
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Scalar(Scalar::one()),
|
||||
fn one() -> DifferentiableTagged<A, Tag> {
|
||||
DifferentiableTagged {
|
||||
contents: DifferentiableContents::Scalar(Scalar::one(), Tag::zero()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Clone for DifferentiableContents<A>
|
||||
impl<A, Tag> Clone for DifferentiableContents<A, Tag>
|
||||
where
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
match self {
|
||||
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
|
||||
Self::Scalar(arg0, tag) => Self::Scalar(arg0.clone(), tag.clone()),
|
||||
Self::Vector(arg0, rank) => Self::Vector(arg0.clone(), *rank),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Clone for Differentiable<A>
|
||||
impl<A, Tag> Clone for DifferentiableTagged<A, Tag>
|
||||
where
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Differentiable {
|
||||
DifferentiableTagged {
|
||||
contents: self.contents.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum DifferentiableContents<A> {
|
||||
Scalar(Scalar<A>),
|
||||
enum DifferentiableContents<A, Tag> {
|
||||
Scalar(Scalar<A>, Tag),
|
||||
// Contains the rank.
|
||||
Vector(Vec<Differentiable<A>>, usize),
|
||||
Vector(Vec<DifferentiableTagged<A, Tag>>, usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Differentiable<A> {
|
||||
contents: DifferentiableContents<A>,
|
||||
pub struct DifferentiableTagged<A, Tag> {
|
||||
contents: DifferentiableContents<A, Tag>,
|
||||
}
|
||||
|
||||
impl<A> Display for DifferentiableContents<A>
|
||||
impl<A, Tag> Display for DifferentiableContents<A, Tag>
|
||||
where
|
||||
A: Display,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
||||
DifferentiableContents::Scalar(s, _) => f.write_fmt(format_args!("{}", s)),
|
||||
DifferentiableContents::Vector(v, _rank) => {
|
||||
f.write_char('[')?;
|
||||
for v in v.iter() {
|
||||
@@ -92,7 +96,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Display for Differentiable<A>
|
||||
impl<A, Tag> Display for DifferentiableTagged<A, Tag>
|
||||
where
|
||||
A: Display,
|
||||
{
|
||||
@@ -101,33 +105,58 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> DifferentiableContents<A> {
|
||||
fn map<B, F>(&self, f: &mut F) -> DifferentiableContents<B>
|
||||
/// An untagged differentiable value: a `DifferentiableTagged` whose per-scalar tag is `()`.
pub type Differentiable<A> = DifferentiableTagged<A, ()>;
|
||||
/// An untagged rank-annotated value: a `RankedDifferentiableTagged` whose per-scalar tag is `()`.
pub type RankedDifferentiable<A, const RANK: usize> = RankedDifferentiableTagged<A, (), RANK>;
|
||||
|
||||
impl<A, Tag> DifferentiableContents<A, Tag> {
|
||||
fn map<B, F>(&self, f: &mut F) -> DifferentiableContents<B, Tag>
|
||||
where
|
||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
match self {
|
||||
DifferentiableContents::Scalar(a) => DifferentiableContents::Scalar(f(a.clone())),
|
||||
DifferentiableContents::Scalar(a, tag) => {
|
||||
DifferentiableContents::Scalar(f(a.clone()), (*tag).clone())
|
||||
}
|
||||
DifferentiableContents::Vector(slice, rank) => {
|
||||
DifferentiableContents::Vector(slice.iter().map(|x| x.map(f)).collect(), *rank)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map2<B, C, F>(
|
||||
&self,
|
||||
other: &DifferentiableContents<B>,
|
||||
f: &mut F,
|
||||
) -> DifferentiableContents<C>
|
||||
fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableContents<A, Tag2>
|
||||
where
|
||||
F: FnMut(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||
F: FnMut(Tag) -> Tag2,
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
match self {
|
||||
DifferentiableContents::Scalar(a, tag) => {
|
||||
DifferentiableContents::Scalar((*a).clone(), f((*tag).clone()))
|
||||
}
|
||||
DifferentiableContents::Vector(slice, rank) => {
|
||||
DifferentiableContents::Vector(slice.iter().map(|x| x.map_tag(f)).collect(), *rank)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn map2<B, C, Tag2, Tag3, F>(
|
||||
&self,
|
||||
other: &DifferentiableContents<B, Tag2>,
|
||||
f: &mut F,
|
||||
) -> DifferentiableContents<C, Tag3>
|
||||
where
|
||||
F: FnMut(&Scalar<A>, Tag, &Scalar<B>, Tag2) -> (Scalar<C>, Tag3),
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
Tag: Clone,
|
||||
Tag2: Clone,
|
||||
{
|
||||
match (self, other) {
|
||||
(DifferentiableContents::Scalar(a), DifferentiableContents::Scalar(b)) => {
|
||||
DifferentiableContents::Scalar(f(a, b))
|
||||
(DifferentiableContents::Scalar(a, tag1), DifferentiableContents::Scalar(b, tag2)) => {
|
||||
let (scalar, tag) = f(a, tag1.clone(), b, tag2.clone());
|
||||
DifferentiableContents::Scalar(scalar, tag)
|
||||
}
|
||||
(
|
||||
DifferentiableContents::Vector(slice_a, rank_a),
|
||||
@@ -140,7 +169,7 @@ impl<A> DifferentiableContents<A> {
|
||||
slice_a
|
||||
.iter()
|
||||
.zip(slice_b.iter())
|
||||
.map(|(a, b)| a.map2(b, f))
|
||||
.map(|(a, b)| a.map2_tagged(b, f))
|
||||
.collect(),
|
||||
*rank_a,
|
||||
)
|
||||
@@ -149,16 +178,20 @@ impl<A> DifferentiableContents<A> {
|
||||
}
|
||||
}
|
||||
|
||||
fn of_slice<'a, T, I>(input: I) -> DifferentiableContents<T>
|
||||
fn of_slice<'a, T, I>(tag: Tag, input: I) -> DifferentiableContents<T, Tag>
|
||||
where
|
||||
T: Clone + 'a,
|
||||
Tag: Clone,
|
||||
I: IntoIterator<Item = &'a T>,
|
||||
{
|
||||
DifferentiableContents::Vector(
|
||||
input
|
||||
.into_iter()
|
||||
.map(|v| Differentiable {
|
||||
contents: DifferentiableContents::Scalar(Scalar::Number(v.clone(), None)),
|
||||
.map(|v| DifferentiableTagged {
|
||||
contents: DifferentiableContents::Scalar(
|
||||
Scalar::Number(v.clone(), None),
|
||||
tag.clone(),
|
||||
),
|
||||
})
|
||||
.collect(),
|
||||
1,
|
||||
@@ -167,87 +200,122 @@ impl<A> DifferentiableContents<A> {
|
||||
|
||||
fn rank(&self) -> usize {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(_) => 0,
|
||||
DifferentiableContents::Scalar(_, _) => 0,
|
||||
DifferentiableContents::Vector(_, rank) => *rank,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
|
||||
impl<A, Tag> DifferentiableTagged<A, Tag> {
|
||||
pub fn map<B, F>(&self, f: &mut F) -> DifferentiableTagged<B, Tag>
|
||||
where
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||
{
|
||||
Differentiable {
|
||||
DifferentiableTagged {
|
||||
contents: self.contents.map(f),
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a new value with the same scalars but with every tag passed through `f`.
///
/// The tag type may change from `Tag` to `Tag2`. `self` is only borrowed; the work is
/// delegated to `map_tag` on the inner contents, and the result is re-wrapped.
pub fn map_tag<Tag2, F>(&self, f: &mut F) -> DifferentiableTagged<A, Tag2>
|
||||
where
|
||||
F: FnMut(Tag) -> Tag2,
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
DifferentiableTagged {
|
||||
contents: self.contents.map_tag(f),
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines `self` with `other` using `f`, producing a value with a new scalar type and tag type.
///
/// `f` is given a scalar reference and a tag from each side and must return the resulting
/// scalar together with its tag. Both inputs are only borrowed.
///
/// NOTE(review): all of the work happens in `map2` on the inner contents; how mismatched
/// shapes are handled is decided there — confirm before relying on it.
pub fn map2_tagged<B, C, Tag2, Tag3, F>(
|
||||
&self,
|
||||
other: &DifferentiableTagged<B, Tag2>,
|
||||
f: &mut F,
|
||||
) -> DifferentiableTagged<C, Tag3>
|
||||
where
|
||||
F: FnMut(&Scalar<A>, Tag, &Scalar<B>, Tag2) -> (Scalar<C>, Tag3),
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
Tag2: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
DifferentiableTagged {
|
||||
contents: self.contents.map2(&other.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes `self` and wraps it in a `RankedDifferentiableTagged` carrying `RANK` in its type.
///
/// Returns `Some` only when the rank reported by the contents equals `RANK`;
/// otherwise returns `None` (and `self` is dropped).
pub fn attach_rank<const RANK: usize>(
|
||||
self: DifferentiableTagged<A, Tag>,
|
||||
) -> Option<RankedDifferentiableTagged<A, Tag, RANK>> {
|
||||
if self.contents.rank() == RANK {
|
||||
Some(RankedDifferentiableTagged { contents: self })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps the scalar `s` together with `tag` as a scalar-variant `DifferentiableTagged`.
pub fn of_scalar_tagged(s: Scalar<A>, tag: Tag) -> DifferentiableTagged<A, Tag> {
|
||||
DifferentiableTagged {
|
||||
contents: DifferentiableContents::Scalar(s, tag),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &mut F) -> Differentiable<C>
|
||||
where
|
||||
F: FnMut(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
{
|
||||
Differentiable {
|
||||
contents: self.contents.map2(&other.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn attach_rank<const RANK: usize>(
|
||||
self: Differentiable<A>,
|
||||
) -> Option<RankedDifferentiable<A, RANK>> {
|
||||
if self.contents.rank() == RANK {
|
||||
Some(RankedDifferentiable { contents: self })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
pub fn of_scalar(s: Scalar<A>) -> Differentiable<A> {
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::Scalar(s),
|
||||
}
|
||||
DifferentiableTagged::map2_tagged(self, other, &mut |a, (), b, ()| (f(a, b), ()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> DifferentiableContents<A> {
|
||||
// Constructors that exist only for the untagged (`Tag = ()`) form.
impl<A> Differentiable<A> {
|
||||
/// Wraps the scalar `s` as an untagged value by calling `of_scalar_tagged` with the unit tag.
pub fn of_scalar(s: Scalar<A>) -> Differentiable<A> {
|
||||
DifferentiableTagged::of_scalar_tagged(s, ())
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, Tag> DifferentiableContents<A, Tag> {
|
||||
fn into_scalar(self) -> Scalar<A> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(s) => s,
|
||||
DifferentiableContents::Scalar(s, _) => s,
|
||||
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||
fn into_vector(self) -> Vec<DifferentiableTagged<A, Tag>> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(_) => panic!("not a vector"),
|
||||
DifferentiableContents::Scalar(_, _) => panic!("not a vector"),
|
||||
DifferentiableContents::Vector(v, _) => v,
|
||||
}
|
||||
}
|
||||
|
||||
fn borrow_scalar(&self) -> &Scalar<A> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(s) => s,
|
||||
DifferentiableContents::Scalar(s, _) => s,
|
||||
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
|
||||
}
|
||||
}
|
||||
|
||||
fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||
fn borrow_vector(&self) -> &Vec<DifferentiableTagged<A, Tag>> {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(_) => panic!("not a vector"),
|
||||
DifferentiableContents::Scalar(_, _) => panic!("not a vector"),
|
||||
DifferentiableContents::Vector(v, _) => v,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> Differentiable<A> {
|
||||
impl<A, Tag> DifferentiableTagged<A, Tag> {
|
||||
pub fn into_scalar(self) -> Scalar<A> {
|
||||
self.contents.into_scalar()
|
||||
}
|
||||
|
||||
pub fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||
pub fn into_vector(self) -> Vec<DifferentiableTagged<A, Tag>> {
|
||||
self.contents.into_vector()
|
||||
}
|
||||
|
||||
@@ -255,26 +323,27 @@ impl<A> Differentiable<A> {
|
||||
self.contents.borrow_scalar()
|
||||
}
|
||||
|
||||
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||
pub fn borrow_vector(&self) -> &Vec<DifferentiableTagged<A, Tag>> {
|
||||
self.contents.borrow_vector()
|
||||
}
|
||||
|
||||
fn of_slice<'a, T>(input: T) -> Differentiable<A>
|
||||
fn of_slice<'a, T>(input: T, tag: Tag) -> DifferentiableTagged<A, Tag>
|
||||
where
|
||||
A: Clone + 'a,
|
||||
Tag: Clone,
|
||||
T: IntoIterator<Item = &'a A>,
|
||||
{
|
||||
Differentiable {
|
||||
contents: DifferentiableContents::<A>::of_slice(input),
|
||||
DifferentiableTagged {
|
||||
contents: DifferentiableContents::<A, Tag>::of_slice(tag, input),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn of_vec(input: Vec<Differentiable<A>>) -> Differentiable<A> {
|
||||
pub fn of_vec(input: Vec<DifferentiableTagged<A, Tag>>) -> DifferentiableTagged<A, Tag> {
|
||||
if input.is_empty() {
|
||||
panic!("Can't make an empty tensor");
|
||||
}
|
||||
let rank = input[0].rank();
|
||||
Differentiable {
|
||||
DifferentiableTagged {
|
||||
contents: DifferentiableContents::Vector(input, 1 + rank),
|
||||
}
|
||||
}
|
||||
@@ -284,7 +353,7 @@ impl<A> Differentiable<A> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> DifferentiableContents<A>
|
||||
impl<A, Tag> DifferentiableContents<A, Tag>
|
||||
where
|
||||
A: Clone
|
||||
+ Eq
|
||||
@@ -297,7 +366,10 @@ where
|
||||
+ One
|
||||
+ Neg<Output = A>,
|
||||
{
|
||||
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
|
||||
fn accumulate_gradients_vec(
|
||||
v: &[DifferentiableTagged<A, Tag>],
|
||||
acc: &mut HashMap<Scalar<A>, A>,
|
||||
) {
|
||||
for v in v.iter().rev() {
|
||||
v.contents.accumulate_gradients(acc);
|
||||
}
|
||||
@@ -305,7 +377,7 @@ where
|
||||
|
||||
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
|
||||
match self {
|
||||
DifferentiableContents::Scalar(y) => {
|
||||
DifferentiableContents::Scalar(y, _) => {
|
||||
let k = y.clone_link();
|
||||
k.invoke(y, A::one(), acc);
|
||||
}
|
||||
@@ -317,8 +389,11 @@ where
|
||||
|
||||
fn grad_once<const PARAM_NUM: usize>(
|
||||
self,
|
||||
wrt: [Differentiable<A>; PARAM_NUM],
|
||||
) -> [Differentiable<A>; PARAM_NUM] {
|
||||
wrt: [DifferentiableTagged<A, Tag>; PARAM_NUM],
|
||||
) -> [DifferentiableTagged<A, Tag>; PARAM_NUM]
|
||||
where
|
||||
Tag: Clone,
|
||||
{
|
||||
let mut acc = HashMap::new();
|
||||
self.accumulate_gradients(&mut acc);
|
||||
|
||||
@@ -332,11 +407,11 @@ where
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RankedDifferentiable<A, const RANK: usize> {
|
||||
contents: Differentiable<A>,
|
||||
pub struct RankedDifferentiableTagged<A, Tag, const RANK: usize> {
|
||||
contents: DifferentiableTagged<A, Tag>,
|
||||
}
|
||||
|
||||
impl<A, const RANK: usize> Display for RankedDifferentiable<A, RANK>
|
||||
impl<A, Tag, const RANK: usize> Display for RankedDifferentiableTagged<A, Tag, RANK>
|
||||
where
|
||||
A: Display,
|
||||
{
|
||||
@@ -345,14 +420,35 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> RankedDifferentiable<A, 0> {
|
||||
impl<A, Tag> RankedDifferentiableTagged<A, Tag, 0> {
|
||||
pub fn to_scalar(self) -> Scalar<A> {
|
||||
self.contents.contents.into_scalar()
|
||||
}
|
||||
|
||||
pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> {
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::of_scalar(s),
|
||||
pub fn of_scalar_tagged(s: Scalar<A>, tag: Tag) -> RankedDifferentiableTagged<A, Tag, 0> {
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::of_scalar_tagged(s, tag),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Constructors that exist only for the untagged (`Tag = ()`) rank-0 form.
impl<A> RankedDifferentiable<A, 0> {
|
||||
/// Wraps the scalar `s` as an untagged rank-0 value by calling `of_scalar_tagged`
/// with the unit tag.
pub fn of_scalar(s: Scalar<A>) -> Self {
|
||||
RankedDifferentiableTagged::of_scalar_tagged(s, ())
|
||||
}
|
||||
}
|
||||
|
||||
// Constructors for rank-1 tagged values.
impl<A, Tag> RankedDifferentiableTagged<A, Tag, 1> {
|
||||
/// Builds a rank-1 value from an iterator of `&A`.
///
/// `input` and `tag` are forwarded to `DifferentiableContents::of_slice`
/// (note the swapped argument order at that call), and the result is wrapped
/// directly as rank 1 without a run-time rank check.
pub fn of_slice_tagged<'a, T>(input: T, tag: Tag) -> RankedDifferentiableTagged<A, Tag, 1>
|
||||
where
|
||||
A: Clone + 'a,
|
||||
Tag: Clone,
|
||||
T: IntoIterator<Item = &'a A>,
|
||||
{
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged {
|
||||
contents: DifferentiableContents::<A, Tag>::of_slice(tag, input),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -363,10 +459,26 @@ impl<A> RankedDifferentiable<A, 1> {
|
||||
A: Clone + 'a,
|
||||
T: IntoIterator<Item = &'a A>,
|
||||
{
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable {
|
||||
contents: DifferentiableContents::<A>::of_slice(input),
|
||||
},
|
||||
RankedDifferentiableTagged::of_slice_tagged(input, ())
|
||||
}
|
||||
}
|
||||
|
||||
// Constructors for rank-2 tagged values.
impl<A, Tag> RankedDifferentiableTagged<A, Tag, 2> {
|
||||
/// Builds a rank-2 value from a slice of rows, each row viewable as `&[A]`.
///
/// Every row is converted with `DifferentiableTagged::of_slice`, receiving its own
/// clone of `tag`; the rows are then stacked with `DifferentiableTagged::of_vec`.
///
/// NOTE(review): the const parameter `N` is not referenced anywhere in this body.
/// NOTE(review): an empty `input` yields an empty row vector, which is handed to
/// `of_vec` as-is — check how `of_vec` treats an empty vector.
pub fn of_slice_2_tagged<T, const N: usize>(
|
||||
input: &[T],
|
||||
tag: Tag,
|
||||
) -> RankedDifferentiableTagged<A, Tag, 2>
|
||||
where
|
||||
A: Clone,
|
||||
T: AsRef<[A]>,
|
||||
Tag: Clone,
|
||||
{
|
||||
let v = input
|
||||
.iter()
|
||||
.map(|x| DifferentiableTagged::of_slice(x.as_ref(), tag.clone()))
|
||||
.collect::<Vec<_>>();
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::of_vec(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -377,33 +489,84 @@ impl<A> RankedDifferentiable<A, 2> {
|
||||
A: Clone,
|
||||
T: AsRef<[A]>,
|
||||
{
|
||||
let v = input
|
||||
.iter()
|
||||
.map(|x| Differentiable::of_slice(x.as_ref()))
|
||||
.collect::<Vec<_>>();
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::of_vec(v),
|
||||
}
|
||||
RankedDifferentiableTagged::of_slice_2_tagged::<_, N>(input, ())
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
||||
pub fn to_unranked(self) -> Differentiable<A> {
|
||||
impl<A, Tag, const RANK: usize> RankedDifferentiableTagged<A, Tag, RANK> {
|
||||
pub fn to_unranked(self) -> DifferentiableTagged<A, Tag> {
|
||||
self.contents
|
||||
}
|
||||
|
||||
pub fn to_unranked_borrow(&self) -> &Differentiable<A> {
|
||||
pub fn to_unranked_borrow(&self) -> &DifferentiableTagged<A, Tag> {
|
||||
&self.contents
|
||||
}
|
||||
|
||||
pub fn of_vector(
|
||||
s: Vec<RankedDifferentiable<A, RANK>>,
|
||||
) -> RankedDifferentiable<A, { RANK + 1 }> {
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::of_vec(s.into_iter().map(|v| v.contents).collect()),
|
||||
s: Vec<RankedDifferentiableTagged<A, Tag, RANK>>,
|
||||
) -> RankedDifferentiableTagged<A, Tag, { RANK + 1 }> {
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::of_vec(s.into_iter().map(|v| v.contents).collect()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies `f` to every scalar, keeping the tag type and the rank unchanged.
///
/// Takes `self` by value but only borrows the inner value when calling
/// `DifferentiableTagged::map`; the result is re-wrapped at the same `RANK`.
pub fn map_tagged<B, F>(
|
||||
self: RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
f: &mut F,
|
||||
) -> RankedDifferentiableTagged<B, Tag, RANK>
|
||||
where
|
||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||
A: Clone,
|
||||
Tag: Clone,
|
||||
{
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::map(&self.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
/// Transforms every tag with `f`, keeping the scalars and the rank unchanged.
///
/// Takes `self` by value but only borrows the inner value when calling
/// `DifferentiableTagged::map_tag`; the tag type may change from `Tag` to `Tag2`.
pub fn map_tag<Tag2, F>(
|
||||
self: RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
f: &mut F,
|
||||
) -> RankedDifferentiableTagged<A, Tag2, RANK>
|
||||
where
|
||||
A: Clone,
|
||||
F: FnMut(Tag) -> Tag2,
|
||||
Tag: Clone,
|
||||
{
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::map_tag(&self.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
/// Combines two values of the same `RANK` using `f`, producing a value of that same rank.
///
/// `f` is given a scalar reference and a tag from each side and returns the resulting
/// scalar and tag. Both inputs are borrowed; the work is delegated to
/// `DifferentiableTagged::map2_tagged` on the unranked contents.
pub fn map2_tagged<B, C, Tag2, Tag3, F>(
|
||||
self: &RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
other: &RankedDifferentiableTagged<B, Tag2, RANK>,
|
||||
f: &mut F,
|
||||
) -> RankedDifferentiableTagged<C, Tag3, RANK>
|
||||
where
|
||||
F: FnMut(&Scalar<A>, Tag, &Scalar<B>, Tag2) -> (Scalar<C>, Tag3),
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
Tag: Clone,
|
||||
Tag2: Clone,
|
||||
{
|
||||
RankedDifferentiableTagged {
|
||||
contents: DifferentiableTagged::map2_tagged(&self.contents, &other.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumes `self` and returns its elements, each typed one rank lower (`RANK - 1`).
///
/// The elements come from `into_vector` on the unranked contents and are re-wrapped
/// without a run-time rank check.
///
/// NOTE(review): behaviour when `self` holds a scalar is whatever `into_vector`
/// does in that case — verify before calling on rank 0.
pub fn to_vector(
|
||||
self: RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
) -> Vec<RankedDifferentiableTagged<A, Tag, { RANK - 1 }>> {
|
||||
self.contents
|
||||
.into_vector()
|
||||
.into_iter()
|
||||
.map(|v| RankedDifferentiableTagged { contents: v })
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
||||
pub fn map<B, F>(
|
||||
self: RankedDifferentiable<A, RANK>,
|
||||
f: &mut F,
|
||||
@@ -412,9 +575,7 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||
A: Clone,
|
||||
{
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::map(&self.contents, f),
|
||||
}
|
||||
self.map_tagged(f)
|
||||
}
|
||||
|
||||
pub fn map2<B, C, F>(
|
||||
@@ -427,28 +588,18 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
||||
A: Clone,
|
||||
B: Clone,
|
||||
{
|
||||
RankedDifferentiable {
|
||||
contents: Differentiable::map2(&self.contents, &other.contents, f),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vector(
|
||||
self: RankedDifferentiable<A, RANK>,
|
||||
) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> {
|
||||
self.contents
|
||||
.into_vector()
|
||||
.into_iter()
|
||||
.map(|v| RankedDifferentiable { contents: v })
|
||||
.collect()
|
||||
self.map2_tagged(other, &mut |a, (), b, ()| (f(a, b), ()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn grad<A, F, const RANK: usize, const PARAM_RANK: usize>(
|
||||
pub fn grad<A, Tag, F, const RANK: usize, const PARAM_RANK: usize>(
|
||||
mut f: F,
|
||||
theta: &[Differentiable<A>; PARAM_RANK],
|
||||
) -> [Differentiable<A>; PARAM_RANK]
|
||||
theta: &[DifferentiableTagged<A, Tag>; PARAM_RANK],
|
||||
) -> [DifferentiableTagged<A, Tag>; PARAM_RANK]
|
||||
where
|
||||
F: FnMut(&[Differentiable<A>; PARAM_RANK]) -> RankedDifferentiable<A, RANK>,
|
||||
F: FnMut(
|
||||
&[DifferentiableTagged<A, Tag>; PARAM_RANK],
|
||||
) -> RankedDifferentiableTagged<A, Tag, RANK>,
|
||||
A: ?Sized
|
||||
+ Clone
|
||||
+ Hash
|
||||
@@ -460,6 +611,7 @@ where
|
||||
+ One
|
||||
+ Neg<Output = A>
|
||||
+ Eq,
|
||||
Tag: Clone,
|
||||
{
|
||||
let mut i = 0usize;
|
||||
let wrt = theta.each_ref().map(|theta| {
|
||||
@@ -482,19 +634,19 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
fn extract_scalar<A>(d: &Differentiable<A>) -> &A {
|
||||
fn extract_scalar<A, Tag>(d: &DifferentiableTagged<A, Tag>) -> &A {
|
||||
d.borrow_scalar().real_part()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_map() {
|
||||
let v = Differentiable::of_vec(
|
||||
let v = DifferentiableTagged::of_vec(
|
||||
vec![
|
||||
Differentiable::of_scalar(Scalar::Number(
|
||||
NotNan::new(3.0).expect("3 is not NaN"),
|
||||
Some(0usize),
|
||||
)),
|
||||
Differentiable::of_scalar(Scalar::Number(
|
||||
DifferentiableTagged::of_scalar(Scalar::Number(
|
||||
NotNan::new(4.0).expect("4 is not NaN"),
|
||||
Some(1usize),
|
||||
)),
|
||||
@@ -518,38 +670,40 @@ mod tests {
|
||||
#[test]
|
||||
fn test_autodiff() {
|
||||
let input_vec = [
|
||||
RankedDifferentiable::of_scalar(Scalar::<NotNan<f64>>::zero()).contents,
|
||||
RankedDifferentiable::of_scalar(Scalar::<NotNan<f64>>::zero()).contents,
|
||||
RankedDifferentiableTagged::of_scalar(Scalar::<NotNan<f64>>::zero()).contents,
|
||||
RankedDifferentiableTagged::of_scalar(Scalar::<NotNan<f64>>::zero()).contents,
|
||||
];
|
||||
let xs = [2.0, 1.0, 4.0, 3.0].map(|x| NotNan::new(x).expect("not nan"));
|
||||
let ys = [1.8, 1.2, 4.2, 3.3].map(|x| NotNan::new(x).expect("not nan"));
|
||||
let grad = grad(
|
||||
|x| {
|
||||
RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(l2_loss_2(
|
||||
predict_line_2_unranked,
|
||||
RankedDifferentiable::of_slice(xs.iter()),
|
||||
RankedDifferentiable::of_slice(ys.iter()),
|
||||
x,
|
||||
))])
|
||||
RankedDifferentiableTagged::of_vector(vec![RankedDifferentiable::of_scalar(
|
||||
l2_loss_2(
|
||||
predict_line_2_unranked,
|
||||
RankedDifferentiableTagged::of_slice(xs.iter()),
|
||||
RankedDifferentiableTagged::of_slice(ys.iter()),
|
||||
x,
|
||||
),
|
||||
)])
|
||||
},
|
||||
&input_vec,
|
||||
);
|
||||
|
||||
let grad_vec = grad
|
||||
.map(Differentiable::into_scalar)
|
||||
.map(DifferentiableTagged::into_scalar)
|
||||
.map(|x| f64::from(*x.real_part()));
|
||||
assert_eq!(grad_vec, [-63.0, -21.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn grad_example() {
|
||||
let input_vec = [Differentiable::of_scalar(Scalar::make(
|
||||
let input_vec = [DifferentiableTagged::of_scalar(Scalar::make(
|
||||
NotNan::new(27.0).expect("not nan"),
|
||||
))];
|
||||
|
||||
let grad: Vec<_> = grad(
|
||||
|x| {
|
||||
RankedDifferentiable::of_scalar(
|
||||
RankedDifferentiableTagged::of_scalar(
|
||||
x[0].borrow_scalar().clone() * x[0].borrow_scalar().clone(),
|
||||
)
|
||||
},
|
||||
@@ -565,19 +719,21 @@ mod tests {
|
||||
fn loss_gradient() {
|
||||
let zero = Scalar::<NotNan<f64>>::zero();
|
||||
let input_vec = [
|
||||
RankedDifferentiable::of_scalar(zero.clone()).to_unranked(),
|
||||
RankedDifferentiable::of_scalar(zero).to_unranked(),
|
||||
RankedDifferentiableTagged::of_scalar(zero.clone()).to_unranked(),
|
||||
RankedDifferentiableTagged::of_scalar(zero).to_unranked(),
|
||||
];
|
||||
let xs = to_not_nan_1([2.0, 1.0, 4.0, 3.0]);
|
||||
let ys = to_not_nan_1([1.8, 1.2, 4.2, 3.3]);
|
||||
let grad = grad(
|
||||
|x| {
|
||||
RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(l2_loss_2(
|
||||
predict_line_2_unranked,
|
||||
RankedDifferentiable::of_slice(&xs),
|
||||
RankedDifferentiable::of_slice(&ys),
|
||||
x,
|
||||
))])
|
||||
RankedDifferentiableTagged::of_vector(vec![RankedDifferentiableTagged::of_scalar(
|
||||
l2_loss_2(
|
||||
predict_line_2_unranked,
|
||||
RankedDifferentiableTagged::of_slice(&xs),
|
||||
RankedDifferentiableTagged::of_slice(&ys),
|
||||
x,
|
||||
),
|
||||
)])
|
||||
},
|
||||
&input_vec,
|
||||
);
|
||||
|
@@ -3,9 +3,10 @@ use std::{
|
||||
ops::{Add, Mul, Neg},
|
||||
};
|
||||
|
||||
use crate::auto_diff::Differentiable;
|
||||
use crate::traits::NumLike;
|
||||
use crate::{
|
||||
auto_diff::{Differentiable, RankedDifferentiable},
|
||||
auto_diff::{DifferentiableTagged, RankedDifferentiable},
|
||||
scalar::Scalar,
|
||||
traits::{One, Zero},
|
||||
};
|
||||
@@ -27,11 +28,27 @@ where
|
||||
RankedDifferentiable::map2(x, y, &mut |x, y| x.clone() * y.clone())
|
||||
}
|
||||
|
||||
/// Multiplies `x` and `y` scalar-by-scalar, merging the two tags with `combine_tags`.
///
/// Each output scalar is `x_i * y_i` (both operands cloned) and each output tag is
/// `combine_tags(tag_x, tag_y)`. This function performs only that pairwise
/// multiplication via `map2_tagged`; it does not sum the products itself.
pub fn dot_unranked_tagged<A, Tag1, Tag2, Tag3, F>(
|
||||
x: &DifferentiableTagged<A, Tag1>,
|
||||
y: &DifferentiableTagged<A, Tag2>,
|
||||
mut combine_tags: F,
|
||||
) -> DifferentiableTagged<A, Tag3>
|
||||
where
|
||||
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Clone + Default,
|
||||
F: FnMut(Tag1, Tag2) -> Tag3,
|
||||
Tag1: Clone,
|
||||
Tag2: Clone,
|
||||
{
|
||||
DifferentiableTagged::map2_tagged(x, y, &mut |x, tag1, y, tag2| {
|
||||
(x.clone() * y.clone(), combine_tags(tag1, tag2))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dot_unranked<A>(x: &Differentiable<A>, y: &Differentiable<A>) -> Differentiable<A>
|
||||
where
|
||||
A: Mul<Output = A> + Sum<<A as Mul>::Output> + Clone + Default,
|
||||
{
|
||||
Differentiable::map2(x, y, &mut |x, y| x.clone() * y.clone())
|
||||
dot_unranked_tagged(x, y, |(), ()| ())
|
||||
}
|
||||
|
||||
fn squared_2<A, const RANK: usize>(
|
||||
@@ -127,7 +144,7 @@ where
|
||||
let dotted = RankedDifferentiable::of_scalar(
|
||||
dot_unranked(
|
||||
left_arg.to_unranked_borrow(),
|
||||
&Differentiable::of_vec(theta.to_vec()),
|
||||
&DifferentiableTagged::of_vec(theta.to_vec()),
|
||||
)
|
||||
.into_vector()
|
||||
.into_iter()
|
||||
@@ -181,7 +198,7 @@ where
|
||||
);
|
||||
dot_unranked(
|
||||
x_powers.to_unranked_borrow(),
|
||||
&Differentiable::of_vec(theta.to_vec()),
|
||||
&DifferentiableTagged::of_vec(theta.to_vec()),
|
||||
)
|
||||
.attach_rank::<1>()
|
||||
.expect("wanted a tensor1")
|
||||
@@ -220,10 +237,11 @@ where
|
||||
RankedDifferentiable::of_vector(dotted)
|
||||
}
|
||||
|
||||
pub struct Predictor<F, Inflated, Deflated> {
|
||||
pub struct Predictor<F, Inflated, Deflated, Params> {
|
||||
pub predict: F,
|
||||
pub inflate: fn(Deflated) -> Inflated,
|
||||
pub deflate: fn(Inflated) -> Deflated,
|
||||
pub update: fn(Inflated, &Deflated, Params) -> Inflated,
|
||||
}
|
||||
|
||||
type ParameterPredictor<T, const INPUT_DIM: usize, const THETA: usize> =
|
||||
@@ -232,39 +250,91 @@ type ParameterPredictor<T, const INPUT_DIM: usize, const THETA: usize> =
|
||||
&[Differentiable<T>; THETA],
|
||||
) -> RankedDifferentiable<T, 1>;
|
||||
|
||||
pub const fn plane_predictor<T>() -> Predictor<ParameterPredictor<T, 2, 2>, Scalar<T>, Scalar<T>>
|
||||
/// Hyperparameters consumed by the `update` step built in `naked_predictor`.
#[derive(Clone)]
|
||||
pub struct NakedHypers<A> {
|
||||
/// Factor each delta is multiplied by before being subtracted from the parameter.
pub learning_rate: A,
|
||||
}
|
||||
|
||||
/// Builds a `Predictor` around `f` whose parameters carry no extra per-scalar state.
///
/// - `inflate` and `deflate` are both the identity.
/// - `update` computes, for every scalar, `theta - delta * learning_rate`.
pub const fn naked_predictor<F, A>(
|
||||
f: F,
|
||||
) -> Predictor<F, Differentiable<A>, Differentiable<A>, NakedHypers<A>>
|
||||
where
|
||||
A: NumLike,
|
||||
{
|
||||
Predictor {
|
||||
predict: f,
|
||||
inflate: |x| x,
|
||||
deflate: |x| x,
|
||||
|
||||
|
||||
update: |theta, delta, hyper| {
|
||||
// Lift the plain learning rate into a `Scalar` once, then clone it per element.
let learning_rate = Scalar::make(hyper.learning_rate);
|
||||
Differentiable::map2(&theta, delta, &mut |theta, delta| {
|
||||
theta.clone() - delta.clone() * learning_rate.clone()
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Hyperparameters consumed by the `update` step built in `velocity_predictor`.
#[derive(Clone)]
|
||||
pub struct VelocityHypers<A> {
|
||||
/// Factor the delta's real part is multiplied by in the velocity update.
pub learning_rate: A,
|
||||
/// Factor the previous velocity is multiplied by in the velocity update.
pub mu: A,
|
||||
}
|
||||
|
||||
/// Builds a `Predictor` around `f` whose inflated parameters carry a velocity tag per scalar.
///
/// - `inflate` attaches a tag of `A::zero()` to every scalar.
/// - `deflate` replaces every tag with `()`.
/// - `update` computes, for every scalar,
///   `velocity' = mu * velocity - delta * learning_rate` and returns
///   `theta + velocity'` as the new scalar with `velocity'` as its new tag.
pub const fn velocity_predictor<F, A>(
|
||||
f: F,
|
||||
) -> Predictor<F, DifferentiableTagged<A, A>, Differentiable<A>, VelocityHypers<A>>
|
||||
where
|
||||
A: NumLike,
|
||||
{
|
||||
Predictor {
|
||||
predict: f,
|
||||
inflate: |x| x.map_tag(&mut |()| A::zero()),
|
||||
deflate: |x| x.map_tag(&mut |_| ()),
|
||||
update: |theta, delta, hyper| {
|
||||
DifferentiableTagged::map2_tagged(&theta, delta, &mut |theta, velocity, delta, ()| {
|
||||
// Only the real part of the delta feeds the velocity; the subtraction is
// written as adding a negation.
let velocity = hyper.mu.clone() * velocity
|
||||
+ -(delta.clone_real_part() * hyper.learning_rate.clone());
|
||||
// The new velocity is both applied to theta and stored as the updated tag.
(theta.clone() + Scalar::make(velocity.clone()), velocity)
|
||||
})
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn plane_predictor<T>(
|
||||
) -> Predictor<ParameterPredictor<T, 2, 2>, Differentiable<T>, Differentiable<T>, NakedHypers<T>>
|
||||
where
|
||||
T: NumLike + Default,
|
||||
{
|
||||
Predictor {
|
||||
predict: predict_plane,
|
||||
inflate: |x| x,
|
||||
deflate: |x| x,
|
||||
}
|
||||
naked_predictor(predict_plane)
|
||||
}
|
||||
|
||||
/// Returns the velocity-tagged `Predictor` for `predict_plane`, built with `velocity_predictor`.
pub const fn velocity_plane_predictor<T>() -> Predictor<
|
||||
ParameterPredictor<T, 2, 2>,
|
||||
DifferentiableTagged<T, T>,
|
||||
Differentiable<T>,
|
||||
VelocityHypers<T>,
|
||||
>
|
||||
where
|
||||
T: NumLike + Default,
|
||||
{
|
||||
velocity_predictor(predict_plane)
|
||||
}
|
||||
|
||||
pub const fn line_unranked_predictor<T>(
|
||||
) -> Predictor<ParameterPredictor<T, 1, 2>, Scalar<T>, Scalar<T>>
|
||||
) -> Predictor<ParameterPredictor<T, 1, 2>, Differentiable<T>, Differentiable<T>, NakedHypers<T>>
|
||||
where
|
||||
T: NumLike + Default + Copy,
|
||||
{
|
||||
Predictor {
|
||||
predict: predict_line_2_unranked,
|
||||
inflate: |x| x,
|
||||
deflate: |x| x,
|
||||
}
|
||||
naked_predictor(predict_line_2_unranked)
|
||||
}
|
||||
|
||||
pub const fn quadratic_unranked_predictor<T>(
|
||||
) -> Predictor<ParameterPredictor<T, 1, 3>, Scalar<T>, Scalar<T>>
|
||||
) -> Predictor<ParameterPredictor<T, 1, 3>, Differentiable<T>, Differentiable<T>, NakedHypers<T>>
|
||||
where
|
||||
T: NumLike + Default,
|
||||
{
|
||||
Predictor {
|
||||
predict: predict_quadratic_unranked,
|
||||
inflate: |x| x,
|
||||
deflate: |x| x,
|
||||
}
|
||||
naked_predictor(predict_quadratic_unranked)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@@ -7,10 +7,14 @@ mod with_tensor;
|
||||
use core::hash::Hash;
|
||||
use rand::Rng;
|
||||
|
||||
use little_learner::auto_diff::{grad, Differentiable, RankedDifferentiable};
|
||||
use little_learner::auto_diff::{
|
||||
grad, Differentiable, RankedDifferentiable, RankedDifferentiableTagged,
|
||||
};
|
||||
|
||||
use crate::sample::sample2;
|
||||
use little_learner::loss::{l2_loss_2, plane_predictor, Predictor};
|
||||
use little_learner::loss::{
|
||||
l2_loss_2, velocity_plane_predictor, NakedHypers, Predictor, VelocityHypers,
|
||||
};
|
||||
use little_learner::not_nan::{to_not_nan_1, to_not_nan_2};
|
||||
use little_learner::scalar::Scalar;
|
||||
use little_learner::traits::{NumLike, Zero};
|
||||
@@ -27,43 +31,101 @@ where
|
||||
v
|
||||
}
|
||||
|
||||
struct GradientDescentHyper<A, R: Rng> {
|
||||
#[derive(Clone)]
|
||||
struct GradientDescentHyperImmut<A> {
|
||||
learning_rate: A,
|
||||
iterations: u32,
|
||||
sampling: Option<(R, usize)>,
|
||||
mu: A,
|
||||
}
|
||||
|
||||
fn gradient_descent_step<A, F, const RANK: usize, const PARAM_NUM: usize>(
|
||||
struct GradientDescentHyper<A, R: Rng> {
|
||||
sampling: Option<(R, usize)>,
|
||||
params: GradientDescentHyperImmut<A>,
|
||||
}
|
||||
|
||||
impl<A> From<GradientDescentHyperImmut<A>> for VelocityHypers<A> {
|
||||
fn from(val: GradientDescentHyperImmut<A>) -> VelocityHypers<A> {
|
||||
VelocityHypers {
|
||||
learning_rate: val.learning_rate,
|
||||
mu: val.mu,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> From<GradientDescentHyperImmut<A>> for NakedHypers<A> {
|
||||
fn from(val: GradientDescentHyperImmut<A>) -> NakedHypers<A> {
|
||||
NakedHypers {
|
||||
learning_rate: val.learning_rate,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> GradientDescentHyper<A, rand::rngs::StdRng> {
|
||||
fn new(learning_rate: A, iterations: u32, mu: A) -> Self {
|
||||
GradientDescentHyper {
|
||||
params: GradientDescentHyperImmut {
|
||||
learning_rate,
|
||||
iterations,
|
||||
mu,
|
||||
},
|
||||
sampling: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `adjust` takes the previous inflated value, a delta, and the hyperparameters, and returns the new inflated value.
|
||||
fn general_gradient_descent_step<
|
||||
A,
|
||||
F,
|
||||
Inflated,
|
||||
Deflate,
|
||||
Adjust,
|
||||
Hyper,
|
||||
const RANK: usize,
|
||||
const PARAM_NUM: usize,
|
||||
>(
|
||||
f: &mut F,
|
||||
theta: [Differentiable<A>; PARAM_NUM],
|
||||
learning_rate: A,
|
||||
) -> [Differentiable<A>; PARAM_NUM]
|
||||
theta: [Inflated; PARAM_NUM],
|
||||
deflate: Deflate,
|
||||
hyper: Hyper,
|
||||
mut adjust: Adjust,
|
||||
) -> [Inflated; PARAM_NUM]
|
||||
where
|
||||
A: Clone + NumLike + Hash + Eq,
|
||||
F: FnMut(&[Differentiable<A>; PARAM_NUM]) -> RankedDifferentiable<A, RANK>,
|
||||
Deflate: FnMut(Inflated) -> Differentiable<A>,
|
||||
Inflated: Clone,
|
||||
Hyper: Clone,
|
||||
Adjust: FnMut(Inflated, &Differentiable<A>, Hyper) -> Inflated,
|
||||
{
|
||||
let delta = grad(f, &theta);
|
||||
let deflated = theta.clone().map(deflate);
|
||||
let delta = grad(f, &deflated);
|
||||
let mut i = 0;
|
||||
theta.map(|theta| {
|
||||
theta.map(|inflated| {
|
||||
let delta = &delta[i];
|
||||
i += 1;
|
||||
// For speed, you might want to truncate_dual this.
|
||||
let learning_rate = Scalar::make(learning_rate.clone());
|
||||
Differentiable::map2(
|
||||
&theta,
|
||||
&delta.map(&mut |s| s * learning_rate.clone()),
|
||||
&mut |theta, delta| (*theta).clone() - (*delta).clone(),
|
||||
)
|
||||
adjust(inflated, delta, hyper.clone())
|
||||
})
|
||||
}
|
||||
|
||||
fn gradient_descent<'a, T, R: Rng, Point, F, G, const IN_SIZE: usize, const PARAM_NUM: usize>(
|
||||
mut hyper: GradientDescentHyper<T, R>,
|
||||
fn gradient_descent<
|
||||
'a,
|
||||
T,
|
||||
R: Rng,
|
||||
Point,
|
||||
F,
|
||||
G,
|
||||
Inflated,
|
||||
Hyper,
|
||||
const IN_SIZE: usize,
|
||||
const PARAM_NUM: usize,
|
||||
>(
|
||||
hyper: &mut GradientDescentHyper<T, R>,
|
||||
xs: &'a [Point],
|
||||
to_ranked_differentiable: G,
|
||||
ys: &[T],
|
||||
zero_params: [Differentiable<T>; PARAM_NUM],
|
||||
mut predictor: Predictor<F, Scalar<T>, Scalar<T>>,
|
||||
mut predictor: Predictor<F, Inflated, Differentiable<T>, Hyper>,
|
||||
) -> [Differentiable<T>; PARAM_NUM]
|
||||
where
|
||||
T: NumLike + Hash + Copy + Default,
|
||||
@@ -73,11 +135,14 @@ where
|
||||
&[Differentiable<T>; PARAM_NUM],
|
||||
) -> RankedDifferentiable<T, 1>,
|
||||
G: for<'b> Fn(&'b [Point]) -> RankedDifferentiable<T, IN_SIZE>,
|
||||
Inflated: Clone,
|
||||
Hyper: Clone,
|
||||
GradientDescentHyperImmut<T>: Into<Hyper>,
|
||||
{
|
||||
let iterations = hyper.iterations;
|
||||
iterate(
|
||||
let iterations = hyper.params.iterations;
|
||||
let out = iterate(
|
||||
|theta| {
|
||||
let out = gradient_descent_step::<T, _, 1, PARAM_NUM>(
|
||||
general_gradient_descent_step(
|
||||
&mut |x| match hyper.sampling.as_mut() {
|
||||
None => RankedDifferentiable::of_vector(vec![RankedDifferentiable::of_scalar(
|
||||
l2_loss_2(
|
||||
@@ -99,14 +164,16 @@ where
|
||||
)])
|
||||
}
|
||||
},
|
||||
theta.map(|x| x.map(&mut predictor.inflate)),
|
||||
hyper.learning_rate,
|
||||
);
|
||||
out.map(|x| x.map(&mut predictor.deflate))
|
||||
theta,
|
||||
predictor.deflate,
|
||||
hyper.params.clone().into(),
|
||||
predictor.update,
|
||||
)
|
||||
},
|
||||
zero_params,
|
||||
zero_params.map(predictor.inflate),
|
||||
iterations,
|
||||
)
|
||||
);
|
||||
out.map(&mut predictor.deflate)
|
||||
}
|
||||
|
||||
fn collect_vec<T>(input: RankedDifferentiable<NotNan<T>, 1>) -> Vec<T>
|
||||
@@ -131,11 +198,11 @@ fn main() {
|
||||
];
|
||||
let plane_ys = [13.99, 15.99, 18.0, 22.4, 30.2, 37.94];
|
||||
|
||||
let hyper = GradientDescentHyper {
|
||||
learning_rate: NotNan::new(0.001).expect("not nan"),
|
||||
iterations: 1000,
|
||||
sampling: None::<(rand::rngs::StdRng, _)>,
|
||||
};
|
||||
let mut hyper = GradientDescentHyper::new(
|
||||
NotNan::new(0.001).expect("not nan"),
|
||||
1000,
|
||||
NotNan::new(0.9).expect("not nan"),
|
||||
);
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(plane_xs);
|
||||
@@ -147,12 +214,12 @@ fn main() {
|
||||
];
|
||||
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&mut hyper,
|
||||
&xs,
|
||||
RankedDifferentiable::of_slice_2::<_, 2>,
|
||||
RankedDifferentiableTagged::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
zero_params,
|
||||
plane_predictor(),
|
||||
velocity_plane_predictor(),
|
||||
)
|
||||
};
|
||||
|
||||
@@ -161,17 +228,19 @@ fn main() {
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
assert_eq!(collect_vec(theta0), [3.97757644609063, 2.0496557321494446]);
|
||||
assert_eq!(collect_vec(theta0), [3.979645447136021, 1.976454920954754]);
|
||||
assert_eq!(
|
||||
theta1.to_scalar().real_part().into_inner(),
|
||||
5.786758464448078
|
||||
6.169579045974949
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use little_learner::loss::{line_unranked_predictor, quadratic_unranked_predictor};
|
||||
use little_learner::loss::{
|
||||
line_unranked_predictor, plane_predictor, quadratic_unranked_predictor,
|
||||
};
|
||||
use rand::SeedableRng;
|
||||
|
||||
#[test]
|
||||
@@ -187,11 +256,11 @@ mod tests {
|
||||
|
||||
let zero = Scalar::<NotNan<f64>>::zero();
|
||||
|
||||
let hyper = GradientDescentHyper {
|
||||
learning_rate: NotNan::new(0.01).expect("not nan"),
|
||||
iterations: 1000,
|
||||
sampling: None::<(rand::rngs::StdRng, _)>,
|
||||
};
|
||||
let mut hyper = GradientDescentHyper::new(
|
||||
NotNan::new(0.01).expect("not nan"),
|
||||
1000,
|
||||
NotNan::new(0.0).expect("not nan"),
|
||||
);
|
||||
let iterated = {
|
||||
let xs = to_not_nan_1(xs);
|
||||
let ys = to_not_nan_1(ys);
|
||||
@@ -200,7 +269,7 @@ mod tests {
|
||||
RankedDifferentiable::of_scalar(zero).to_unranked(),
|
||||
];
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&mut hyper,
|
||||
&xs,
|
||||
|b| RankedDifferentiable::of_slice(b),
|
||||
&ys,
|
||||
@@ -223,11 +292,11 @@ mod tests {
|
||||
|
||||
let zero = Scalar::<NotNan<f64>>::zero();
|
||||
|
||||
let hyper = GradientDescentHyper {
|
||||
learning_rate: NotNan::new(0.001).expect("not nan"),
|
||||
iterations: 1000,
|
||||
sampling: None::<(rand::rngs::StdRng, _)>,
|
||||
};
|
||||
let mut hyper = GradientDescentHyper::new(
|
||||
NotNan::new(0.001).expect("not nan"),
|
||||
1000,
|
||||
NotNan::new(0.0).expect("not nan"),
|
||||
);
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_1(xs);
|
||||
@@ -238,7 +307,7 @@ mod tests {
|
||||
RankedDifferentiable::of_scalar(zero).to_unranked(),
|
||||
];
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&mut hyper,
|
||||
&xs,
|
||||
|b| RankedDifferentiable::of_slice(b),
|
||||
&ys,
|
||||
@@ -269,11 +338,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn optimise_plane() {
|
||||
let hyper = GradientDescentHyper {
|
||||
learning_rate: NotNan::new(0.001).expect("not nan"),
|
||||
iterations: 1000,
|
||||
sampling: None::<(rand::rngs::StdRng, _)>,
|
||||
};
|
||||
let mut hyper = GradientDescentHyper::new(
|
||||
NotNan::new(0.001).expect("not nan"),
|
||||
1000,
|
||||
NotNan::new(0.0).expect("not nan"),
|
||||
);
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(PLANE_XS);
|
||||
@@ -283,7 +352,7 @@ mod tests {
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&mut hyper,
|
||||
&xs,
|
||||
RankedDifferentiable::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
@@ -307,9 +376,12 @@ mod tests {
|
||||
#[test]
|
||||
fn optimise_plane_with_sampling() {
|
||||
let rng = rand::rngs::StdRng::seed_from_u64(314159);
|
||||
let hyper = GradientDescentHyper {
|
||||
learning_rate: NotNan::new(0.001).expect("not nan"),
|
||||
iterations: 1000,
|
||||
let mut hyper = GradientDescentHyper {
|
||||
params: GradientDescentHyperImmut {
|
||||
learning_rate: NotNan::new(0.001).expect("not nan"),
|
||||
iterations: 1000,
|
||||
mu: NotNan::new(0.0).expect("not nan"),
|
||||
},
|
||||
sampling: Some((rng, 4)),
|
||||
};
|
||||
|
||||
@@ -321,7 +393,7 @@ mod tests {
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
gradient_descent(
|
||||
hyper,
|
||||
&mut hyper,
|
||||
&xs,
|
||||
RankedDifferentiable::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
@@ -367,4 +439,43 @@ mod tests {
|
||||
assert_eq!(theta0, [3.8581694055684781, 2.2166222673968554]);
|
||||
assert_eq!(theta1, 5.2839863438547159);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_velocity() {
|
||||
let mut hyper = GradientDescentHyper::new(
|
||||
NotNan::new(0.001).expect("not nan"),
|
||||
1000,
|
||||
NotNan::new(0.9).expect("not nan"),
|
||||
);
|
||||
|
||||
let iterated = {
|
||||
let xs = to_not_nan_2(PLANE_XS);
|
||||
let ys = to_not_nan_1(PLANE_YS);
|
||||
let zero_params = [
|
||||
RankedDifferentiable::of_slice(&[NotNan::<f64>::zero(), NotNan::<f64>::zero()])
|
||||
.to_unranked(),
|
||||
Differentiable::of_scalar(Scalar::zero()),
|
||||
];
|
||||
|
||||
gradient_descent(
|
||||
&mut hyper,
|
||||
&xs,
|
||||
RankedDifferentiableTagged::of_slice_2::<_, 2>,
|
||||
&ys,
|
||||
zero_params,
|
||||
velocity_plane_predictor(),
|
||||
)
|
||||
};
|
||||
|
||||
let [theta0, theta1] = iterated;
|
||||
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
assert_eq!(collect_vec(theta0), [3.979645447136021, 1.976454920954754]);
|
||||
assert_eq!(
|
||||
theta1.to_scalar().real_part().into_inner(),
|
||||
6.169579045974949
|
||||
);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user