Class SparseMax

  • All Implemented Interfaces:
    Block

    public class SparseMax
    extends AbstractBlock
    SparseMax contains a generic implementation of the sparsemax function; for the definition of sparsemax, see https://arxiv.org/pdf/1602.02068.pdf. This block is a simplified version of sparsemax in which K is a hyperparameter (default 3): softmax is applied only to the K largest values along the chosen axis, and all other values are set to 0.
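    The top-K behaviour described above can be illustrated with a minimal plain-Java sketch on a single 1-D array. This is not the library's implementation: the class name TopKSoftmaxSketch and the method topKSoftmax are hypothetical, and the sketch assumes ties are broken by index order and that the maximum is subtracted before exponentiation for numerical stability.

```java
import java.util.Arrays;

public class TopKSoftmaxSketch {

    /**
     * Applies softmax to the k largest entries of x and sets every other entry to 0.
     * Hypothetical helper for illustration only.
     */
    public static double[] topKSoftmax(double[] x, int k) {
        int n = x.length;
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("k must be between 1 and x.length");
        }

        // Sort the indices by value in descending order; the sort is stable,
        // so equal values keep their original index order.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) {
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(x[b], x[a]));

        // Subtract the maximum before exponentiating to avoid overflow.
        double max = x[order[0]];
        double[] out = new double[n]; // entries outside the top k stay 0
        double sum = 0.0;
        for (int r = 0; r < k; r++) {
            int i = order[r];
            out[i] = Math.exp(x[i] - max);
            sum += out[i];
        }

        // Normalize so the k kept entries sum to 1.
        for (int r = 0; r < k; r++) {
            out[order[r]] /= sum;
        }
        return out;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 4.0, 2.0, 3.0, 0.5};
        // With k = 3, only the entries for 4.0, 3.0 and 2.0 are non-zero.
        System.out.println(Arrays.toString(topKSoftmax(scores, 3)));
    }
}
```

    In the block itself the same idea is applied along the configured axis of each input array rather than to a single 1-D array.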
    • Constructor Detail

      • SparseMax

        public SparseMax()
        Creates a sparsemax activation function over the last axis that keeps the top 3 elements.
      • SparseMax

        public SparseMax(int axis)
        Creates a sparsemax activation function along the given axis that keeps the top 3 elements.
        Parameters:
        axis - the axis along which to apply sparsemax
      • SparseMax

        public SparseMax(int axis,
                         int topK)
        Creates a sparsemax activation function along the given axis that keeps the given number of elements.
        Parameters:
        axis - the axis along which to apply sparsemax
        topK - the hyperparameter K, the number of largest elements to keep
    • Method Detail

      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block