Package ai.djl.nn.core
Class SparseMax
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.core.SparseMax
All Implemented Interfaces:
Block
SparseMax contains a generic implementation of the sparsemax function; for the definition of sparsemax, see https://arxiv.org/pdf/1602.02068.pdf. This SparseMax is a simplified implementation of the sparsemax function in which K is a hyperparameter (default 3): softmax is applied only to the K largest values, and all other values are set to 0.
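The top-K rule described above can be illustrated on a plain array. This is a minimal sketch, not the DJL implementation: the class name `TopKSoftmaxSketch` and the method `topKSoftmax` are hypothetical, it works on a one-dimensional `double[]` rather than an NDArray, and subtracting the maximum before exponentiating is a numerical-stability choice made here that does not change the result.

```java
import java.util.Arrays;

final class TopKSoftmaxSketch {

    /** Applies softmax to the k largest entries of x and sets every other entry to 0. */
    public static double[] topKSoftmax(double[] x, int k) {
        if (k < 1 || k > x.length) {
            throw new IllegalArgumentException("k must be between 1 and x.length");
        }
        // Sort the indices of x by value, largest first.
        Integer[] order = new Integer[x.length];
        for (int i = 0; i < x.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(x[b], x[a]));

        // Shift by the maximum so that Math.exp cannot overflow (assumption of this sketch).
        double max = x[order[0]];
        double[] out = new double[x.length]; // entries outside the top k stay 0
        double sum = 0.0;
        for (int j = 0; j < k; j++) {
            int idx = order[j];
            out[idx] = Math.exp(x[idx] - max);
            sum += out[idx];
        }
        // Normalize the k kept entries so that they sum to 1.
        for (int j = 0; j < k; j++) {
            out[order[j]] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        // With k = 3 the entries 4.0, 3.0 and 2.0 are kept; 1.0 and 0.5 become 0.
        double[] result = topKSoftmax(new double[] {1.0, 4.0, 2.0, 3.0, 0.5}, 3);
        System.out.println(Arrays.toString(result));
    }
}
```

In this example the first and last outputs are exactly 0, and the three remaining outputs are positive and sum to 1.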
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
Constructors
Constructor - Description
SparseMax() - Creates a sparsemax activation function for the last axis and 3 elements.
SparseMax(int axis) - Creates a sparsemax activation function along a given axis for 3 elements.
SparseMax(int axis, int topK) - Creates a sparsemax activation function along a given axis and number of elements.
Method Summary
Modifier and Type - Method - Description
protected NDList - forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params) - A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] - getOutputShapes(Shape[] inputShapes) - Returns the expected output shapes of the block for the specified input shapes.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
Constructor Details
SparseMax
public SparseMax()
Creates a sparsemax activation function for the last axis and 3 elements.
SparseMax
public SparseMax(int axis)
Creates a sparsemax activation function along a given axis for 3 elements.
Parameters:
axis - the axis along which sparsemax is applied
SparseMax
public SparseMax(int axis, int topK)
Creates a sparsemax activation function along a given axis and number of elements.
Parameters:
axis - the axis along which sparsemax is applied
topK - the hyperparameter K, the number of largest values that softmax is applied to
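The effect of the axis parameter can be seen on a small two-dimensional array. The following is a sketch in plain Java rather than DJL code: the class `AxisTopKSoftmaxSketch` and its methods are hypothetical, the input is assumed to be a rectangular `double[][]`, and accepting -1 as an alias for the last axis is a convention chosen for this sketch. K is set to 1 so that every output is exactly 0 or 1.

```java
import java.util.Arrays;

final class AxisTopKSoftmaxSketch {

    /** Softmax over the k largest entries of x; every other entry becomes 0. */
    public static double[] topKSoftmax(double[] x, int k) {
        if (k < 1 || k > x.length) {
            throw new IllegalArgumentException("k must be between 1 and x.length");
        }
        Integer[] order = new Integer[x.length];
        for (int i = 0; i < x.length; i++) {
            order[i] = i;
        }
        // Indices sorted by value, largest first.
        Arrays.sort(order, (a, b) -> Double.compare(x[b], x[a]));
        double[] out = new double[x.length];
        double sum = 0.0;
        for (int j = 0; j < k; j++) {
            out[order[j]] = Math.exp(x[order[j]] - x[order[0]]);
            sum += out[order[j]];
        }
        for (int j = 0; j < k; j++) {
            out[order[j]] /= sum;
        }
        return out;
    }

    /** Applies topKSoftmax along axis 0 (down each column) or axis 1 / -1 (across each row). */
    public static double[][] apply(double[][] m, int axis, int k) {
        int rows = m.length;
        int cols = m[0].length;
        double[][] out = new double[rows][cols];
        if (axis == 1 || axis == -1) {
            for (int r = 0; r < rows; r++) {
                out[r] = topKSoftmax(m[r], k);
            }
        } else if (axis == 0) {
            for (int c = 0; c < cols; c++) {
                // Copy the column out, transform it, and write it back.
                double[] column = new double[rows];
                for (int r = 0; r < rows; r++) {
                    column[r] = m[r][c];
                }
                double[] result = topKSoftmax(column, k);
                for (int r = 0; r < rows; r++) {
                    out[r][c] = result[r];
                }
            }
        } else {
            throw new IllegalArgumentException("axis must be 0, 1 or -1 for a 2-D input");
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = {{1.0, 5.0, 2.0}, {4.0, 0.0, 3.0}};
        // Last axis: the largest value in each row is kept.
        System.out.println(Arrays.deepToString(apply(m, -1, 1))); // [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
        // Axis 0: the largest value in each column is kept.
        System.out.println(Arrays.deepToString(apply(m, 0, 1)));  // [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
    }
}
```

With a larger K the kept entries would hold softmax weights instead of a single 1, but the choice of axis decides in the same way which values compete with each other.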
Method Details
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputShapes - the shapes of the inputs
Returns:
the expected output shapes of the block
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass