This commit is contained in:
Patrick Stevens
2023-06-17 18:03:24 +01:00
committed by GitHub
parent 5bb1bddf83
commit 095a8af7f2

View File

@@ -185,6 +185,27 @@ where
.map(&mut rectify)
}
pub fn k_relu<A, Tag>(
t: &RankedDifferentiableTagged<A, Tag, 1>,
theta: &[Differentiable<A>],
) -> Differentiable<A>
where
Tag: Clone,
A: NumLike + PartialOrd + Default,
{
assert!(theta.len() < 2, "Needed at least 2 parameters for k_relu");
let once = relu(
t,
&theta[0].clone().attach_rank::<2>().unwrap(),
&theta[1].clone().attach_rank::<1>().unwrap(),
);
if theta.len() == 2 {
once
} else {
k_relu(&once.attach_rank().unwrap(), &theta[2..])
}
}
#[cfg(test)]
mod tests {
use crate::auto_diff::{Differentiable, RankedDifferentiable};