Package ai.djl.nn.core
Class SparseMax
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.nn.core.SparseMax
- All Implemented Interfaces:
Block
public class SparseMax extends AbstractBlock
SparseMax contains a generic implementation of the sparsemax function. For the definition of sparsemax, see https://arxiv.org/pdf/1602.02068.pdf.
SparseMax is a simplified implementation of the sparsemax function in which K is a hyperparameter (default 3): softmax is applied only to the K largest values, and all other values are set to 0.
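The top-K behaviour described above can be illustrated with a minimal plain-Java sketch on a one-dimensional array. This is not the DJL implementation: the class name `TopKSoftmaxSketch` and the method `topKSoftmax` are hypothetical, and the handling of ties (lower index wins) and of K larger than the input length (K is clamped) are assumptions made for the example.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of ai.djl.nn.core.
final class TopKSoftmaxSketch {

    // Applies softmax to the k largest entries of x and sets all other entries to 0.
    static double[] topKSoftmax(double[] x, int k) {
        int n = x.length;
        double[] out = new double[n]; // entries default to 0
        int kept = Math.max(0, Math.min(k, n)); // assumption: clamp k to [0, n]
        if (kept == 0) {
            return out;
        }

        // Sort indices by value, largest first. Sorting an object array is stable,
        // so ties keep their original order (assumption: lower index wins).
        Integer[] idx = new Integer[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i;
        }
        Arrays.sort(idx, (a, b) -> Double.compare(x[b], x[a]));

        // Subtract the maximum before exponentiating for numerical stability.
        double max = x[idx[0]];
        double sum = 0.0;
        for (int r = 0; r < kept; r++) {
            int i = idx[r];
            out[i] = Math.exp(x[i] - max);
            sum += out[i];
        }
        // Normalize the kept entries so they sum to 1.
        for (int r = 0; r < kept; r++) {
            out[idx[r]] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] y = topKSoftmax(new double[] {1.0, 4.0, 2.0, 3.0, 0.5}, 3);
        System.out.println(Arrays.toString(y));
    }
}
```

With K = 3, the entries at indices 0 and 4 (the two smallest inputs) come out as 0, and the remaining three entries are positive and sum to 1.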
-
-
Field Summary
-
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
-
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, version
-
-
Constructor Summary
SparseMax()
Creates a sparseMax activation function for the last axis and 3 elements.
SparseMax(int axis)
Creates a sparseMax activation function along a given axis for 3 elements.
SparseMax(int axis, int topK)
Creates a sparseMax activation function along a given axis and number of elements.
-
Method Summary
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters
-
-
-
-
Constructor Detail
-
SparseMax
public SparseMax()
Creates a sparseMax activation function for the last axis and 3 elements.
-
SparseMax
public SparseMax(int axis)
Creates a sparseMax activation function along a given axis for 3 elements.
Parameters:
axis - the axis to do sparseMax for
-
SparseMax
public SparseMax(int axis, int topK)
Creates a sparseMax activation function along a given axis and number of elements.
Parameters:
axis - the axis to do sparseMax for
topK - the hyperparameter K, the number of largest values that softmax is applied to; all other values are set to 0
-
-
Method Detail
-
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputShapes - the shapes of the inputs
Returns:
the expected output shapes of the block
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass
-
-