public class Prelu extends AbstractBlock
Leaky ReLUs attempt to fix the 'dying ReLU' problem by allowing a small slope when the input is negative and using a slope of one when the input is positive. This is defined by \(y = x\) if \(x > 0\), and \(y = \text{slope} \cdot x\) otherwise.
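As an illustration of the formula above, here is a minimal plain-Java sketch that applies a Leaky ReLU with a fixed slope to a float array. It does not use DJL; the class and method names (LeakyReluSketch, leakyRelu) are hypothetical.

```java
import java.util.Arrays;

class LeakyReluSketch {
    // Applies y = x if x > 0, otherwise y = slope * x, element by element.
    static float[] leakyRelu(float[] x, float slope) {
        float[] y = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = x[i] > 0 ? x[i] : slope * x[i];
        }
        return y;
    }

    public static void main(String[] args) {
        float[] out = leakyRelu(new float[] {-2f, -0.5f, 0f, 3f}, 0.1f);
        System.out.println(Arrays.toString(out));
    }
}
```

Negative inputs are scaled by the slope instead of being zeroed, so their gradient never becomes exactly zero.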
Parametric ReLU is a Leaky ReLU in which the slope is learnt during training.
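To show what "learnt during training" means, the sketch below treats the slope alpha as a trainable scalar. For non-positive inputs \(\partial y / \partial \alpha = x\), and it is zero otherwise, so plain gradient descent on a squared-error loss recovers the slope that generated the targets. This is a hypothetical, framework-free illustration: PreluSlopeSketch, fitAlpha, the learning rate, and the toy data are all assumptions, and in a real training loop a framework would compute this gradient automatically.

```java
class PreluSlopeSketch {
    // Forward pass: y = x for x > 0, otherwise alpha * x.
    static float forward(float x, float alpha) {
        return x > 0 ? x : alpha * x;
    }

    // dy/dalpha is 0 on the positive side and x on the non-positive side.
    static float gradAlpha(float x) {
        return x > 0 ? 0f : x;
    }

    // Fits alpha by gradient descent on L = 0.5 * sum((y - t)^2).
    static float fitAlpha(float[] xs, float[] targets, float alpha, float lr, int steps) {
        for (int s = 0; s < steps; s++) {
            float grad = 0f;
            for (int i = 0; i < xs.length; i++) {
                float dLdy = forward(xs[i], alpha) - targets[i]; // upstream gradient
                grad += dLdy * gradAlpha(xs[i]);                 // chain rule
            }
            alpha -= lr * grad;
        }
        return alpha;
    }

    public static void main(String[] args) {
        float[] xs = {-2f, -1f, 1f, 2f};
        float[] ts = new float[xs.length];
        for (int i = 0; i < xs.length; i++) {
            ts[i] = forward(xs[i], 0.25f); // targets generated with a "true" slope of 0.25
        }
        float learned = fitAlpha(xs, ts, 0.01f, 0.1f, 50);
        System.out.println("learned slope = " + learned);
    }
}
```

With this toy data the gradient is \(5(\alpha - 0.25)\), so each step computes \(\alpha \leftarrow \alpha - 0.5(\alpha - 0.25)\) and the error halves every iteration.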
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, version
Constructor | Description
---|---
Prelu() | Creates a Parametric ReLU Block.
Modifier and Type | Method | Description
---|---|---
protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] | getOutputShapes(Shape[] inputs) | Returns the expected output shapes of the block for the specified input shapes.
void | loadMetadata(byte loadVersion, java.io.DataInputStream is) | Override this to load additional metadata with the parameter values.
static NDList | prelu(NDArray input, NDArray alpha) | Applies a Prelu activation on the input NDArray.
Methods inherited from class AbstractBlock: addChildBlock, addParameter, beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block: forward, validateLayout
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by: forwardInternal in class AbstractBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
public Shape[] getOutputShapes(Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputs - the shapes of the inputs
public void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method. This default implementation checks whether the version number fits and, if it does not, throws a MalformedModelException. After that, it restores the input shapes.
Overrides: loadMetadata in class AbstractBlock
Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has wrong format
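The version check and shape restoration described above can be sketched in plain Java with java.io streams. The byte layout (shape count, then each shape's rank and dimensions), the VERSION constant, and the nested MalformedModelException stand-in are hypothetical and do not reproduce DJL's actual serialization format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

class MetadataSketch {
    // Hypothetical stand-in for ai.djl.MalformedModelException.
    static class MalformedModelException extends Exception {
        MalformedModelException(String message) {
            super(message);
        }
    }

    static final byte VERSION = 2; // hypothetical current format version

    // Hypothetical layout: shape count, then for each shape its rank followed by its dimensions.
    static void saveInputShapes(DataOutputStream os, long[][] shapes) throws IOException {
        os.writeInt(shapes.length);
        for (long[] shape : shapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    static long[][] loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedModelException {
        // Reject versions this sketch does not know how to read.
        if (loadVersion != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        // Restore the input shapes in the order saveInputShapes wrote them.
        int count = is.readInt();
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            shapes[i] = new long[is.readInt()];
            for (int j = 0; j < shapes[i].length; j++) {
                shapes[i][j] = is.readLong();
            }
        }
        return shapes;
    }

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            saveInputShapes(os, new long[][] {{32, 10}});
        }
        DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        System.out.println(Arrays.deepToString(loadMetadata(VERSION, is)));
    }
}
```

An override that supports an older format would typically branch on loadVersion before this check, reading the older layout instead of rejecting it.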