Little refactor allowing speed (#10)
This commit is contained in:
@@ -257,7 +257,7 @@ impl<A, const RANK: usize> Differentiable<A, RANK> {
|
|||||||
{
|
{
|
||||||
let mut i = 0usize;
|
let mut i = 0usize;
|
||||||
let wrt = theta.contents.map(&mut |x| {
|
let wrt = theta.contents.map(&mut |x| {
|
||||||
let result = Scalar::truncate_dual(x, i);
|
let result = Scalar::truncate_dual(x, Some(i));
|
||||||
i += 1;
|
i += 1;
|
||||||
result
|
result
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -231,11 +231,11 @@ impl<A> Scalar<A> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn truncate_dual(self, index: usize) -> Scalar<A>
|
pub fn truncate_dual(self, index: Option<usize>) -> Scalar<A>
|
||||||
where
|
where
|
||||||
A: Clone,
|
A: Clone,
|
||||||
{
|
{
|
||||||
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(Some(index)))
|
Scalar::Dual(self.clone_real_part(), Link::EndOfLink(index))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn make(x: A) -> Scalar<A> {
|
pub fn make(x: A) -> Scalar<A> {
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ fn iterate<A, F>(f: &F, start: A, n: u32) -> A
|
|||||||
where
|
where
|
||||||
F: Fn(A) -> A,
|
F: Fn(A) -> A,
|
||||||
{
|
{
|
||||||
if n == 0 {
|
let mut v = start;
|
||||||
return start;
|
for _ in 0..n {
|
||||||
|
v = f(v);
|
||||||
}
|
}
|
||||||
iterate(f, f(start), n - 1)
|
v
|
||||||
}
|
}
|
||||||
|
|
||||||
struct GradientDescentHyper<A, const RANK: usize> {
|
struct GradientDescentHyper<A, const RANK: usize> {
|
||||||
|
|||||||
Reference in New Issue
Block a user