Little refactor allowing speed (#10)

This commit is contained in:
Patrick Stevens
2023-04-07 11:37:05 +01:00
committed by GitHub
parent 817775412b
commit 3c964bc132
3 changed files with 7 additions and 6 deletions

View File

@@ -257,7 +257,7 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
{
let mut i = 0usize;
let wrt = theta.contents.map(&mut |x| {
let result = Scalar::truncate_dual(x, i);
let result = Scalar::truncate_dual(x, Some(i));
i += 1;
result
});

View File

@@ -231,11 +231,11 @@ impl<A> Scalar<A> {
}
}
pub fn truncate_dual(self, index: usize) -> Scalar<A>
pub fn truncate_dual(self, index: Option<usize>) -> Scalar<A>
where
A: Clone,
{
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(Some(index)))
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(index))
}
pub fn make(x: A) -> Scalar<A> {

View File

@@ -17,10 +17,11 @@ fn iterate<A, F>(f: &F, start: A, n: u32) -> A
where
F: Fn(A) -> A,
{
if n == 0 {
return start;
let mut v = start;
for _ in 0..n {
v = f(v);
}
iterate(f, f(start), n - 1)
v
}
struct GradientDescentHyper<A, const RANK: usize> {