Smoothing (#17)
This commit is contained in:
@@ -176,17 +176,6 @@ where
|
||||
out.map(&mut predictor.deflate)
|
||||
}
|
||||
|
||||
fn collect_vec<T>(input: RankedDifferentiable<NotNan<T>, 1>) -> Vec<T>
|
||||
where
|
||||
T: Copy,
|
||||
{
|
||||
input
|
||||
.to_vector()
|
||||
.into_iter()
|
||||
.map(|x| x.to_scalar().real_part().into_inner())
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let plane_xs = [
|
||||
[1.0, 2.05],
|
||||
@@ -228,7 +217,7 @@ fn main() {
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
assert_eq!(collect_vec(theta0), [3.979645447136021, 1.976454920954754]);
|
||||
assert_eq!(theta0.collect(), [3.979645447136021, 1.976454920954754]);
|
||||
assert_eq!(
|
||||
theta1.to_scalar().real_part().into_inner(),
|
||||
6.169579045974949
|
||||
@@ -366,7 +355,7 @@ mod tests {
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
assert_eq!(collect_vec(theta0), [3.97757644609063, 2.0496557321494446]);
|
||||
assert_eq!(theta0.collect(), [3.97757644609063, 2.0496557321494446]);
|
||||
assert_eq!(
|
||||
theta1.to_scalar().real_part().into_inner(),
|
||||
5.786758464448078
|
||||
@@ -404,7 +393,7 @@ mod tests {
|
||||
|
||||
let [theta0, theta1] = iterated;
|
||||
|
||||
let theta0 = collect_vec(theta0.attach_rank::<1>().expect("rank 1 tensor"));
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor").collect();
|
||||
let theta1 = theta1
|
||||
.attach_rank::<0>()
|
||||
.expect("rank 0 tensor")
|
||||
@@ -472,7 +461,7 @@ mod tests {
|
||||
let theta0 = theta0.attach_rank::<1>().expect("rank 1 tensor");
|
||||
let theta1 = theta1.attach_rank::<0>().expect("rank 0 tensor");
|
||||
|
||||
assert_eq!(collect_vec(theta0), [3.979645447136021, 1.976454920954754]);
|
||||
assert_eq!(theta0.collect(), [3.979645447136021, 1.976454920954754]);
|
||||
assert_eq!(
|
||||
theta1.to_scalar().real_part().into_inner(),
|
||||
6.169579045974949
|
||||
|
Reference in New Issue
Block a user