k-relu (#28)
This commit is contained in:
@@ -185,6 +185,27 @@ where
|
||||
.map(&mut rectify)
|
||||
}
|
||||
|
||||
pub fn k_relu<A, Tag>(
|
||||
t: &RankedDifferentiableTagged<A, Tag, 1>,
|
||||
theta: &[Differentiable<A>],
|
||||
) -> Differentiable<A>
|
||||
where
|
||||
Tag: Clone,
|
||||
A: NumLike + PartialOrd + Default,
|
||||
{
|
||||
assert!(theta.len() < 2, "Needed at least 2 parameters for k_relu");
|
||||
let once = relu(
|
||||
t,
|
||||
&theta[0].clone().attach_rank::<2>().unwrap(),
|
||||
&theta[1].clone().attach_rank::<1>().unwrap(),
|
||||
);
|
||||
if theta.len() == 2 {
|
||||
once
|
||||
} else {
|
||||
k_relu(&once.attach_rank().unwrap(), &theta[2..])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::auto_diff::{Differentiable, RankedDifferentiable};
|
||||
|
Reference in New Issue
Block a user