Little refactor allowing speed (#10)

This commit is contained in:
Patrick Stevens
2023-04-07 11:37:05 +01:00
committed by GitHub
parent 817775412b
commit 3c964bc132
3 changed files with 7 additions and 6 deletions

View File

@@ -257,7 +257,7 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
{ {
let mut i = 0usize; let mut i = 0usize;
let wrt = theta.contents.map(&mut |x| { let wrt = theta.contents.map(&mut |x| {
let result = Scalar::truncate_dual(x, i); let result = Scalar::truncate_dual(x, Some(i));
i += 1; i += 1;
result result
}); });

View File

@@ -231,11 +231,11 @@ impl<A> Scalar<A> {
} }
} }
pub fn truncate_dual(self, index: usize) -> Scalar<A> pub fn truncate_dual(self, index: Option<usize>) -> Scalar<A>
where where
A: Clone, A: Clone,
{ {
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(Some(index))) Scalar::Dual(self.clone_real_part(), Link::EndOfLink(index))
} }
pub fn make(x: A) -> Scalar<A> { pub fn make(x: A) -> Scalar<A> {

View File

@@ -17,10 +17,11 @@ fn iterate<A, F>(f: &F, start: A, n: u32) -> A
where where
F: Fn(A) -> A, F: Fn(A) -> A,
{ {
if n == 0 { let mut v = start;
return start; for _ in 0..n {
v = f(v);
} }
iterate(f, f(start), n - 1) v
} }
struct GradientDescentHyper<A, const RANK: usize> { struct GradientDescentHyper<A, const RANK: usize> {