Add tensor type (#1)
This commit is contained in:
16
src/main.rs
16
src/main.rs
@@ -38,7 +38,23 @@ fn linear_params_2d<A>(m: A, c: A) -> Parameters<A, 2, 1> {
|
||||
[[c, m]]
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tensor {
|
||||
($x:ty , $i: expr) => {[$x; $i]};
|
||||
($x:ty , $i: expr, $($is:expr),+) => {[tensor!($x, $($is),+); $i]};
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let y = line(&[1.0, 7.3], &linear_params_2d(3.0, 1.0));
|
||||
println!("{:?}", y);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_tensor() {
|
||||
let _: tensor!(f64, 1, 2, 3) = [[[1.0, 3.0, 6.0], [-1.3, -30.0, -0.0]]];
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user