Add tensor type (#1)

This commit is contained in:
Patrick Stevens
2023-03-18 20:41:14 +00:00
committed by GitHub
parent e2e3eed1b9
commit 0a199b0065

View File

@@ -38,7 +38,23 @@ fn linear_params_2d<A>(m: A, c: A) -> Parameters<A, 2, 1> {
[[c, m]]
}
#[macro_export]
macro_rules! tensor {
($x:ty , $i: expr) => {[$x; $i]};
($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]};
}
fn main() {
let y = line(&[1.0, 7.3], &linear_params_2d(3.0, 1.0));
println!("{:?}", y);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tensor() {
let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]];
}
}