Little refactor allowing speed (#10)
This commit is contained in:
@@ -257,7 +257,7 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
|
||||
{
|
||||
let mut i = 0usize;
|
||||
let wrt = theta.contents.map(&mut |x| {
|
||||
let result = Scalar::truncate_dual(x, i);
|
||||
let result = Scalar::truncate_dual(x, Some(i));
|
||||
i += 1;
|
||||
result
|
||||
});
|
||||
|
@@ -231,11 +231,11 @@ impl<A> Scalar<A> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn truncate_dual(self, index: usize) -> Scalar<A>
|
||||
pub fn truncate_dual(self, index: Option<usize>) -> Scalar<A>
|
||||
where
|
||||
A: Clone,
|
||||
{
|
||||
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(Some(index)))
|
||||
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(index))
|
||||
}
|
||||
|
||||
pub fn make(x: A) -> Scalar<A> {
|
||||
|
@@ -17,10 +17,11 @@ fn iterate<A, F>(f: &F, start: A, n: u32) -> A
|
||||
where
|
||||
F: Fn(A) -> A,
|
||||
{
|
||||
if n == 0 {
|
||||
return start;
|
||||
let mut v = start;
|
||||
for _ in 0..n {
|
||||
v = f(v);
|
||||
}
|
||||
iterate(f, f(start), n - 1)
|
||||
v
|
||||
}
|
||||
|
||||
struct GradientDescentHyper<A, const RANK: usize> {
|
||||
|
Reference in New Issue
Block a user