k-relu (#28)
This commit is contained in:
@@ -185,6 +185,27 @@ where
|
|||||||
.map(&mut rectify)
|
.map(&mut rectify)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn k_relu<A, Tag>(
|
||||||
|
t: &RankedDifferentiableTagged<A, Tag, 1>,
|
||||||
|
theta: &[Differentiable<A>],
|
||||||
|
) -> Differentiable<A>
|
||||||
|
where
|
||||||
|
Tag: Clone,
|
||||||
|
A: NumLike + PartialOrd + Default,
|
||||||
|
{
|
||||||
|
assert!(theta.len() < 2, "Needed at least 2 parameters for k_relu");
|
||||||
|
let once = relu(
|
||||||
|
t,
|
||||||
|
&theta[0].clone().attach_rank::<2>().unwrap(),
|
||||||
|
&theta[1].clone().attach_rank::<1>().unwrap(),
|
||||||
|
);
|
||||||
|
if theta.len() == 2 {
|
||||||
|
once
|
||||||
|
} else {
|
||||||
|
k_relu(&once.attach_rank().unwrap(), &theta[2..])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::auto_diff::{Differentiable, RankedDifferentiable};
|
use crate::auto_diff::{Differentiable, RankedDifferentiable};
|
||||||
|
Reference in New Issue
Block a user