diff --git a/little_learner/src/ext.rs b/little_learner/src/ext.rs index c9a9207..763ed9b 100644 --- a/little_learner/src/ext.rs +++ b/little_learner/src/ext.rs @@ -185,6 +185,27 @@ where .map(&mut rectify) } +pub fn k_relu( + t: &RankedDifferentiableTagged, + theta: &[Differentiable], +) -> Differentiable +where + Tag: Clone, + A: NumLike + PartialOrd + Default, +{ + assert!(theta.len() < 2, "Needed at least 2 parameters for k_relu"); + let once = relu( + t, + &theta[0].clone().attach_rank::<2>().unwrap(), + &theta[1].clone().attach_rank::<1>().unwrap(), + ); + if theta.len() == 2 { + once + } else { + k_relu(&once.attach_rank().unwrap(), &theta[2..]) + } +} + #[cfg(test)] mod tests { use crate::auto_diff::{Differentiable, RankedDifferentiable};