Class SparseMax

All Implemented Interfaces:
Block

public class SparseMax extends AbstractBlock
SparseMax contains a generic implementation of the sparsemax function; for the definition of sparsemax, see https://arxiv.org/pdf/1602.02068.pdf. This class is a simplified implementation of the sparsemax function in which K is a hyperparameter (default 3): softmax is applied only to the K largest values along the axis, and all other values are set to 0.
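The top-K behavior described above can be illustrated on a plain one-dimensional array. The following is only a minimal sketch under stated assumptions, not the code of this class: the class name `TopKSoftmaxSketch` and the method `topKSoftmax` are hypothetical, it uses `double[]` instead of `NDArray`, and the way ties are broken (the lower index wins) is an assumption.

```java
import java.util.Arrays;

// Hypothetical illustration only; this is not the DJL implementation.
class TopKSoftmaxSketch {

    /** Applies softmax to the k largest entries of x and sets every other entry to 0. */
    public static double[] topKSoftmax(double[] x, int k) {
        if (k < 1 || k > x.length) {
            throw new IllegalArgumentException("k must be between 1 and x.length");
        }

        // Order the indices by value, largest first. The sort is stable,
        // so among equal values the lower index comes first (an assumption).
        Integer[] order = new Integer[x.length];
        for (int i = 0; i < x.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(x[b], x[a]));

        // Subtract the maximum before exponentiating for numerical stability.
        double max = x[order[0]];
        double[] out = new double[x.length]; // entries outside the top k stay 0
        double sum = 0.0;
        for (int j = 0; j < k; j++) {
            int idx = order[j];
            out[idx] = Math.exp(x[idx] - max);
            sum += out[idx];
        }

        // Normalize so that the k kept entries sum to 1.
        for (int j = 0; j < k; j++) {
            out[order[j]] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 3.0, 0.5, 2.0, -1.0};
        // With k = 3, only indices 1, 3 and 0 (values 3.0, 2.0, 1.0) are nonzero.
        System.out.println(Arrays.toString(topKSoftmax(scores, 3)));
    }
}
```

In this sketch the kept entries always sum to 1, and the entries at indices 2 and 4 of the example input are exactly 0, which matches the "softmax on the K largest values, 0 elsewhere" description.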
  • Constructor Details

    • SparseMax

      public SparseMax()
      Creates a sparseMax activation function that operates on the last axis and keeps the 3 largest elements.
    • SparseMax

      public SparseMax(int axis)
      Creates a sparseMax activation function that operates along the given axis and keeps the 3 largest elements.
      Parameters:
      axis - the axis along which to apply sparseMax
    • SparseMax

      public SparseMax(int axis, int topK)
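      Creates a sparseMax activation function that operates along the given axis and keeps the given number of largest elements.
      Parameters:
      axis - the axis along which to apply sparseMax
      topK - the hyperparameter K, the number of largest elements to which softmax is applied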
  • Method Details

    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass