Add tensor type (#1)
This commit is contained in:
16
src/main.rs
16
src/main.rs
@@ -38,7 +38,23 @@ fn linear_params_2d<A>(m: A, c: A) -> Parameters<A, 2, 1> {
|
|||||||
[[c, m]]
|
[[c, m]]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[macro_export]
|
||||||
|
macro_rules! tensor {
|
||||||
|
($x:ty , $i: expr) => {[$x; $i]};
|
||||||
|
($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]};
|
||||||
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
let y = line(&[1.0, 7.3], &linear_params_2d(3.0, 1.0));
|
let y = line(&[1.0, 7.3], &linear_params_2d(3.0, 1.0));
|
||||||
println!("{:?}", y);
|
println!("{:?}", y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tensor() {
|
||||||
|
let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user