diff --git a/little_learner/src/ext.rs b/little_learner/src/ext.rs
index c9a9207..763ed9b 100644
--- a/little_learner/src/ext.rs
+++ b/little_learner/src/ext.rs
@@ -185,6 +185,27 @@ where
.map(&mut rectify)
}
+pub fn k_relu(
+ t: &RankedDifferentiableTagged,
+ theta: &[Differentiable],
+) -> Differentiable
+where
+ Tag: Clone,
+ A: NumLike + PartialOrd + Default,
+{
+ assert!(theta.len() < 2, "Needed at least 2 parameters for k_relu");
+ let once = relu(
+ t,
+ &theta[0].clone().attach_rank::<2>().unwrap(),
+ &theta[1].clone().attach_rank::<1>().unwrap(),
+ );
+ if theta.len() == 2 {
+ once
+ } else {
+ k_relu(&once.attach_rank().unwrap(), &theta[2..])
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::auto_diff::{Differentiable, RankedDifferentiable};