Separate implementation (#12)
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -101,9 +101,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.53"
|
version = "1.0.56"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ba466839c78239c09faf015484e5cc04860f88242cff4d03eb038f04b4699b73"
|
checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
@@ -12,7 +12,9 @@ where
|
|||||||
A: Zero,
|
A: Zero,
|
||||||
{
|
{
|
||||||
fn zero() -> Differentiable<A> {
|
fn zero() -> Differentiable<A> {
|
||||||
Differentiable::Scalar(Scalar::Number(A::zero(), None))
|
Differentiable {
|
||||||
|
contents: DifferentiableContents::Scalar(Scalar::Number(A::zero(), None)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -30,7 +32,21 @@ where
|
|||||||
A: One,
|
A: One,
|
||||||
{
|
{
|
||||||
fn one() -> Differentiable<A> {
|
fn one() -> Differentiable<A> {
|
||||||
Differentiable::Scalar(Scalar::one())
|
Differentiable {
|
||||||
|
contents: DifferentiableContents::Scalar(Scalar::one()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Clone for DifferentiableContents<A>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
{
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
match self {
|
||||||
|
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
|
||||||
|
Self::Vector(arg0, rank) => Self::Vector(arg0.clone(), *rank),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,27 +55,32 @@ where
|
|||||||
A: Clone,
|
A: Clone,
|
||||||
{
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
match self {
|
Differentiable {
|
||||||
Self::Scalar(arg0) => Self::Scalar(arg0.clone()),
|
contents: self.contents.clone(),
|
||||||
Self::Vector(arg0) => Self::Vector(arg0.clone()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum Differentiable<A> {
|
enum DifferentiableContents<A> {
|
||||||
Scalar(Scalar<A>),
|
Scalar(Scalar<A>),
|
||||||
Vector(Vec<Differentiable<A>>),
|
// Contains the rank.
|
||||||
|
Vector(Vec<Differentiable<A>>, usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Display for Differentiable<A>
|
#[derive(Debug)]
|
||||||
|
pub struct Differentiable<A> {
|
||||||
|
contents: DifferentiableContents<A>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Display for DifferentiableContents<A>
|
||||||
where
|
where
|
||||||
A: Display,
|
A: Display,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
match self {
|
match self {
|
||||||
Differentiable::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
DifferentiableContents::Scalar(s) => f.write_fmt(format_args!("{}", s)),
|
||||||
Differentiable::Vector(v) => {
|
DifferentiableContents::Vector(v, _rank) => {
|
||||||
f.write_char('[')?;
|
f.write_char('[')?;
|
||||||
for v in v.iter() {
|
for v in v.iter() {
|
||||||
f.write_fmt(format_args!("{}", v))?;
|
f.write_fmt(format_args!("{}", v))?;
|
||||||
@@ -71,106 +92,196 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<A> Differentiable<A> {
|
impl<A> Display for Differentiable<A>
|
||||||
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
|
where
|
||||||
|
A: Display,
|
||||||
|
{
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.write_fmt(format_args!("{}", self.contents))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> DifferentiableContents<A> {
|
||||||
|
fn map<B, F>(&self, f: &mut F) -> DifferentiableContents<B>
|
||||||
where
|
where
|
||||||
F: FnMut(Scalar<A>) -> Scalar<B>,
|
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||||
A: Clone,
|
A: Clone,
|
||||||
{
|
{
|
||||||
match self {
|
match self {
|
||||||
Differentiable::Scalar(a) => Differentiable::Scalar(f(a.clone())),
|
DifferentiableContents::Scalar(a) => DifferentiableContents::Scalar(f(a.clone())),
|
||||||
Differentiable::Vector(slice) => {
|
DifferentiableContents::Vector(slice, rank) => {
|
||||||
Differentiable::Vector(slice.iter().map(|x| x.map(f)).collect())
|
DifferentiableContents::Vector(slice.iter().map(|x| x.map(f)).collect(), *rank)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn map2<B, C, F>(&self, other: &DifferentiableContents<B>, f: &F) -> DifferentiableContents<C>
|
||||||
|
where
|
||||||
|
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||||
|
A: Clone,
|
||||||
|
B: Clone,
|
||||||
|
{
|
||||||
|
match (self, other) {
|
||||||
|
(DifferentiableContents::Scalar(a), DifferentiableContents::Scalar(b)) => {
|
||||||
|
DifferentiableContents::Scalar(f(a, b))
|
||||||
|
}
|
||||||
|
(
|
||||||
|
DifferentiableContents::Vector(slice_a, rank_a),
|
||||||
|
DifferentiableContents::Vector(slice_b, rank_b),
|
||||||
|
) => {
|
||||||
|
if rank_a != rank_b {
|
||||||
|
panic!("Unexpectedly different ranks in map2");
|
||||||
|
}
|
||||||
|
DifferentiableContents::Vector(
|
||||||
|
slice_a
|
||||||
|
.iter()
|
||||||
|
.zip(slice_b.iter())
|
||||||
|
.map(|(a, b)| a.map2(b, f))
|
||||||
|
.collect(),
|
||||||
|
*rank_a,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
_ => panic!("Wrong shapes!"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn of_slice<T>(input: T) -> DifferentiableContents<A>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
T: AsRef<[A]>,
|
||||||
|
{
|
||||||
|
DifferentiableContents::Vector(
|
||||||
|
input
|
||||||
|
.as_ref()
|
||||||
|
.iter()
|
||||||
|
.map(|v| Differentiable {
|
||||||
|
contents: DifferentiableContents::Scalar(Scalar::Number((*v).clone(), None)),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rank(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
DifferentiableContents::Scalar(_) => 0,
|
||||||
|
DifferentiableContents::Vector(_, rank) => *rank,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Differentiable<A> {
|
||||||
|
pub fn map<B, F>(&self, f: &mut F) -> Differentiable<B>
|
||||||
|
where
|
||||||
|
A: Clone,
|
||||||
|
F: FnMut(Scalar<A>) -> Scalar<B>,
|
||||||
|
{
|
||||||
|
Differentiable {
|
||||||
|
contents: self.contents.map(f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &F) -> Differentiable<C>
|
pub fn map2<B, C, F>(&self, other: &Differentiable<B>, f: &F) -> Differentiable<C>
|
||||||
where
|
where
|
||||||
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
F: Fn(&Scalar<A>, &Scalar<B>) -> Scalar<C>,
|
||||||
A: Clone,
|
A: Clone,
|
||||||
B: Clone,
|
B: Clone,
|
||||||
{
|
{
|
||||||
match (self, other) {
|
Differentiable {
|
||||||
(Differentiable::Scalar(a), Differentiable::Scalar(b)) => {
|
contents: self.contents.map2(&other.contents, f),
|
||||||
Differentiable::Scalar(f(a, b))
|
|
||||||
}
|
}
|
||||||
(Differentiable::Vector(slice_a), Differentiable::Vector(slice_b)) => {
|
|
||||||
Differentiable::Vector(
|
|
||||||
slice_a
|
|
||||||
.iter()
|
|
||||||
.zip(slice_b.iter())
|
|
||||||
.map(|(a, b)| a.map2(b, f))
|
|
||||||
.collect(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
_ => panic!("Wrong shapes!"),
|
|
||||||
|
pub fn attach_rank<const RANK: usize>(
|
||||||
|
self: Differentiable<A>,
|
||||||
|
) -> Option<RankedDifferentiable<A, RANK>> {
|
||||||
|
if self.contents.rank() == RANK {
|
||||||
|
Some(RankedDifferentiable { contents: self })
|
||||||
|
} else {
|
||||||
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn of_scalar(s: Scalar<A>) -> Differentiable<A> {
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableContents::Scalar(s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> DifferentiableContents<A> {
|
||||||
|
fn into_scalar(self) -> Scalar<A> {
|
||||||
|
match self {
|
||||||
|
DifferentiableContents::Scalar(s) => s,
|
||||||
|
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||||
|
match self {
|
||||||
|
DifferentiableContents::Scalar(_) => panic!("not a vector"),
|
||||||
|
DifferentiableContents::Vector(v, _) => v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn borrow_scalar(&self) -> &Scalar<A> {
|
||||||
|
match self {
|
||||||
|
DifferentiableContents::Scalar(s) => s,
|
||||||
|
DifferentiableContents::Vector(_, _) => panic!("not a scalar"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||||
|
match self {
|
||||||
|
DifferentiableContents::Scalar(_) => panic!("not a vector"),
|
||||||
|
DifferentiableContents::Vector(v, _) => v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> Differentiable<A> {
|
||||||
|
pub fn into_scalar(self) -> Scalar<A> {
|
||||||
|
self.contents.into_scalar()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_vector(self) -> Vec<Differentiable<A>> {
|
||||||
|
self.contents.into_vector()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn borrow_scalar(&self) -> &Scalar<A> {
|
||||||
|
self.contents.borrow_scalar()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
||||||
|
self.contents.borrow_vector()
|
||||||
|
}
|
||||||
|
|
||||||
fn of_slice<T>(input: T) -> Differentiable<A>
|
fn of_slice<T>(input: T) -> Differentiable<A>
|
||||||
where
|
where
|
||||||
A: Clone,
|
A: Clone,
|
||||||
T: AsRef<[A]>,
|
T: AsRef<[A]>,
|
||||||
{
|
{
|
||||||
Differentiable::Vector(
|
Differentiable {
|
||||||
input
|
contents: DifferentiableContents::of_slice(input),
|
||||||
.as_ref()
|
}
|
||||||
.iter()
|
}
|
||||||
.map(|v| Differentiable::Scalar(Scalar::Number((*v).clone(), None)))
|
|
||||||
.collect(),
|
pub fn of_vec(input: Vec<Differentiable<A>>) -> Differentiable<A> {
|
||||||
)
|
if input.is_empty() {
|
||||||
|
panic!("Can't make an empty tensor");
|
||||||
|
}
|
||||||
|
let rank = input[0].rank();
|
||||||
|
Differentiable {
|
||||||
|
contents: DifferentiableContents::Vector(input, 1 + rank),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rank(&self) -> usize {
|
pub fn rank(&self) -> usize {
|
||||||
match self {
|
self.contents.rank()
|
||||||
Differentiable::Scalar(_) => 0,
|
|
||||||
Differentiable::Vector(v) => v[0].rank() + 1,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn attach_rank<const RANK: usize>(
|
impl<A> DifferentiableContents<A>
|
||||||
self: Differentiable<A>,
|
|
||||||
) -> Option<RankedDifferentiable<A, RANK>> {
|
|
||||||
if self.rank() == RANK {
|
|
||||||
Some(RankedDifferentiable { contents: self })
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Differentiable<A> {
|
|
||||||
pub fn into_scalar(self) -> Scalar<A> {
|
|
||||||
match self {
|
|
||||||
Differentiable::Scalar(s) => s,
|
|
||||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_vector(self) -> Vec<Differentiable<A>> {
|
|
||||||
match self {
|
|
||||||
Differentiable::Scalar(_) => panic!("not a vector"),
|
|
||||||
Differentiable::Vector(v) => v,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn borrow_scalar(&self) -> &Scalar<A> {
|
|
||||||
match self {
|
|
||||||
Differentiable::Scalar(s) => s,
|
|
||||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn borrow_vector(&self) -> &Vec<Differentiable<A>> {
|
|
||||||
match self {
|
|
||||||
Differentiable::Scalar(_) => panic!("not a vector"),
|
|
||||||
Differentiable::Vector(v) => v,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<A> Differentiable<A>
|
|
||||||
where
|
where
|
||||||
A: Clone
|
A: Clone
|
||||||
+ Eq
|
+ Eq
|
||||||
@@ -185,17 +296,19 @@ where
|
|||||||
{
|
{
|
||||||
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
|
fn accumulate_gradients_vec(v: &[Differentiable<A>], acc: &mut HashMap<Scalar<A>, A>) {
|
||||||
for v in v.iter().rev() {
|
for v in v.iter().rev() {
|
||||||
v.accumulate_gradients(acc);
|
v.contents.accumulate_gradients(acc);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
|
fn accumulate_gradients(&self, acc: &mut HashMap<Scalar<A>, A>) {
|
||||||
match self {
|
match self {
|
||||||
Differentiable::Scalar(y) => {
|
DifferentiableContents::Scalar(y) => {
|
||||||
let k = y.clone_link();
|
let k = y.clone_link();
|
||||||
k.invoke(y, A::one(), acc);
|
k.invoke(y, A::one(), acc);
|
||||||
}
|
}
|
||||||
Differentiable::Vector(y) => Differentiable::accumulate_gradients_vec(y, acc),
|
DifferentiableContents::Vector(y, _rank) => {
|
||||||
|
DifferentiableContents::accumulate_gradients_vec(y, acc)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -231,15 +344,12 @@ where
|
|||||||
|
|
||||||
impl<A> RankedDifferentiable<A, 0> {
|
impl<A> RankedDifferentiable<A, 0> {
|
||||||
pub fn to_scalar(self) -> Scalar<A> {
|
pub fn to_scalar(self) -> Scalar<A> {
|
||||||
match self.contents {
|
self.contents.contents.into_scalar()
|
||||||
Differentiable::Scalar(s) => s,
|
|
||||||
Differentiable::Vector(_) => panic!("not a scalar despite teq that we're a scalar"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> {
|
pub fn of_scalar(s: Scalar<A>) -> RankedDifferentiable<A, 0> {
|
||||||
RankedDifferentiable {
|
RankedDifferentiable {
|
||||||
contents: Differentiable::Scalar(s),
|
contents: Differentiable::of_scalar(s),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -251,7 +361,9 @@ impl<A> RankedDifferentiable<A, 1> {
|
|||||||
T: AsRef<[A]>,
|
T: AsRef<[A]>,
|
||||||
{
|
{
|
||||||
RankedDifferentiable {
|
RankedDifferentiable {
|
||||||
contents: Differentiable::of_slice(input),
|
contents: Differentiable {
|
||||||
|
contents: DifferentiableContents::of_slice(input),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -267,7 +379,7 @@ impl<A> RankedDifferentiable<A, 2> {
|
|||||||
.map(|x| Differentiable::of_slice(x))
|
.map(|x| Differentiable::of_slice(x))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
RankedDifferentiable {
|
RankedDifferentiable {
|
||||||
contents: Differentiable::Vector(v),
|
contents: Differentiable::of_vec(v),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -285,7 +397,7 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
|||||||
s: Vec<RankedDifferentiable<A, RANK>>,
|
s: Vec<RankedDifferentiable<A, RANK>>,
|
||||||
) -> RankedDifferentiable<A, { RANK + 1 }> {
|
) -> RankedDifferentiable<A, { RANK + 1 }> {
|
||||||
RankedDifferentiable {
|
RankedDifferentiable {
|
||||||
contents: Differentiable::Vector(s.into_iter().map(|v| v.contents).collect()),
|
contents: Differentiable::of_vec(s.into_iter().map(|v| v.contents).collect()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -320,13 +432,11 @@ impl<A, const RANK: usize> RankedDifferentiable<A, RANK> {
|
|||||||
pub fn to_vector(
|
pub fn to_vector(
|
||||||
self: RankedDifferentiable<A, RANK>,
|
self: RankedDifferentiable<A, RANK>,
|
||||||
) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> {
|
) -> Vec<RankedDifferentiable<A, { RANK - 1 }>> {
|
||||||
match self.contents {
|
self.contents
|
||||||
Differentiable::Scalar(_) => panic!("not a scalar"),
|
.into_vector()
|
||||||
Differentiable::Vector(v) => v
|
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|v| RankedDifferentiable { contents: v })
|
.map(|v| RankedDifferentiable { contents: v })
|
||||||
.collect(),
|
.collect()
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,7 +467,7 @@ where
|
|||||||
})
|
})
|
||||||
});
|
});
|
||||||
let after_f = f(&wrt);
|
let after_f = f(&wrt);
|
||||||
Differentiable::grad_once(after_f.contents, wrt)
|
DifferentiableContents::grad_once(after_f.contents.contents, wrt)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -369,21 +479,18 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
|
fn extract_scalar<'a, A>(d: &'a Differentiable<A>) -> &'a A {
|
||||||
match d {
|
d.borrow_scalar().real_part()
|
||||||
Differentiable::Scalar(a) => &(a.real_part()),
|
|
||||||
Differentiable::Vector(_) => panic!("not a scalar"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_map() {
|
fn test_map() {
|
||||||
let v = Differentiable::Vector(
|
let v = Differentiable::of_vec(
|
||||||
vec![
|
vec![
|
||||||
Differentiable::Scalar(Scalar::Number(
|
Differentiable::of_scalar(Scalar::Number(
|
||||||
NotNan::new(3.0).expect("3 is not NaN"),
|
NotNan::new(3.0).expect("3 is not NaN"),
|
||||||
Some(0usize),
|
Some(0usize),
|
||||||
)),
|
)),
|
||||||
Differentiable::Scalar(Scalar::Number(
|
Differentiable::of_scalar(Scalar::Number(
|
||||||
NotNan::new(4.0).expect("4 is not NaN"),
|
NotNan::new(4.0).expect("4 is not NaN"),
|
||||||
Some(1usize),
|
Some(1usize),
|
||||||
)),
|
)),
|
||||||
@@ -395,13 +502,11 @@ mod tests {
|
|||||||
Scalar::Dual(_, _) => panic!("Not hit"),
|
Scalar::Dual(_, _) => panic!("Not hit"),
|
||||||
});
|
});
|
||||||
|
|
||||||
let v = match mapped {
|
let v = mapped
|
||||||
Differentiable::Scalar(_) => panic!("Not a scalar"),
|
.into_vector()
|
||||||
Differentiable::Vector(v) => v
|
|
||||||
.iter()
|
.iter()
|
||||||
.map(|d| extract_scalar(d).clone())
|
.map(|d| extract_scalar(d).clone())
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>();
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(v, [4.0, 5.0]);
|
assert_eq!(v, [4.0, 5.0]);
|
||||||
}
|
}
|
||||||
|
@@ -126,7 +126,7 @@ where
|
|||||||
let dotted = RankedDifferentiable::of_scalar(
|
let dotted = RankedDifferentiable::of_scalar(
|
||||||
dot_unranked(
|
dot_unranked(
|
||||||
left_arg.to_unranked_borrow(),
|
left_arg.to_unranked_borrow(),
|
||||||
&Differentiable::Vector(theta.to_vec()),
|
&Differentiable::of_vec(theta.to_vec()),
|
||||||
)
|
)
|
||||||
.into_vector()
|
.into_vector()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -180,7 +180,7 @@ where
|
|||||||
);
|
);
|
||||||
dot_unranked(
|
dot_unranked(
|
||||||
x_powers.to_unranked_borrow(),
|
x_powers.to_unranked_borrow(),
|
||||||
&Differentiable::Vector(theta.to_vec()),
|
&Differentiable::of_vec(theta.to_vec()),
|
||||||
)
|
)
|
||||||
.attach_rank::<1>()
|
.attach_rank::<1>()
|
||||||
.expect("wanted a tensor1")
|
.expect("wanted a tensor1")
|
||||||
|
@@ -106,7 +106,7 @@ fn main() {
|
|||||||
},
|
},
|
||||||
[
|
[
|
||||||
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
|
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
|
||||||
Differentiable::Scalar(Scalar::zero()),
|
Differentiable::of_scalar(Scalar::zero()),
|
||||||
],
|
],
|
||||||
hyper.iterations,
|
hyper.iterations,
|
||||||
)
|
)
|
||||||
@@ -168,7 +168,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn grad_example() {
|
fn grad_example() {
|
||||||
let input_vec = [Differentiable::Scalar(Scalar::make(
|
let input_vec = [Differentiable::of_scalar(Scalar::make(
|
||||||
NotNan::new(27.0).expect("not nan"),
|
NotNan::new(27.0).expect("not nan"),
|
||||||
))];
|
))];
|
||||||
|
|
||||||
@@ -362,7 +362,7 @@ mod tests {
|
|||||||
},
|
},
|
||||||
[
|
[
|
||||||
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
|
RankedDifferentiable::of_slice([NotNan::zero(), NotNan::zero()]).to_unranked(),
|
||||||
Differentiable::Scalar(Scalar::zero()),
|
Differentiable::of_scalar(Scalar::zero()),
|
||||||
],
|
],
|
||||||
hyper.iterations,
|
hyper.iterations,
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user