Class TabNet

  • All Implemented Interfaces:
    ai.djl.nn.Block

    public final class TabNet
    extends ai.djl.nn.AbstractBlock
    TabNet contains a generic implementation of TabNet, adapted from https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279 (original author: Samrat Thapa).

    TabNet is a neural architecture for tabular data, developed by the research team at Google Cloud AI. It achieved state-of-the-art results on several datasets for both regression and classification problems. Another desirable property of TabNet is interpretability. Most deep learning models act like black boxes, but with TabNet we can see which features the model selects.

    See https://arxiv.org/pdf/1908.07442.pdf for more information about TabNet.

    • Nested Class Summary

      Nested Classes 
      Modifier and Type Class Description
      static class  TabNet.AttentionTransformer
      AttentionTransformer is where the TabNet model learns the relationships between relevant features and decides which features to pass on to the feature transformer of the current decision step.
      static class  TabNet.Builder
      The Builder to construct a TabNet object.
      static class  TabNet.DecisionStep
      DecisionStep combines a featureTransformer and an attentionTransformer.
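The TabNet paper describes the attentive transformer as producing a sparse feature mask with sparsemax. Sparsemax is applied to logits scaled by a prior, and the prior shrinks for features already used in earlier decision steps (prior *= relaxation - mask). The following plain-Java sketch shows those two operations on a single sample and assumes the paper's formulation. It is not DJL's AttentionTransformer code. The class SparsemaxSketch, its methods, and the relaxation parameter (the paper's gamma) are hypothetical names. The paper's attentive transformer also applies an FC layer and batch normalization before sparsemax, and those steps are omitted here.

```java
import java.util.Arrays;

/** Hypothetical plain-Java sketch of sparsemax feature masking (not DJL's API). */
final class SparsemaxSketch {

    /** Projects logits z onto the probability simplex; small logits become exactly 0. */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending order
        int n = sorted.length;
        double cumSum = 0.0;
        double supportSum = 0.0;
        int k = 0;
        // Walk from the largest value down; the last i satisfying the test is the support size.
        for (int i = 1; i <= n; i++) {
            double zi = sorted[n - i];
            cumSum += zi;
            if (1.0 + i * zi > cumSum) {
                k = i;
                supportSum = cumSum;
            }
        }
        double tau = (supportSum - 1.0) / k; // threshold subtracted from every logit
        double[] out = new double[n];
        for (int i = 0; i < n; i++) {
            out[i] = Math.max(z[i] - tau, 0.0);
        }
        return out;
    }

    /** One attentive step: mask = sparsemax(prior * logits), then prior *= (relaxation - mask). */
    static double[] step(double[] logits, double[] prior, double relaxation) {
        double[] scaled = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            scaled[i] = prior[i] * logits[i];
        }
        double[] mask = sparsemax(scaled);
        for (int i = 0; i < prior.length; i++) {
            prior[i] *= relaxation - mask[i]; // features used now are down-weighted later
        }
        return mask;
    }

    public static void main(String[] args) {
        double[] prior = {1.0, 1.0, 1.0};
        double[] mask = step(new double[] {2.0, 1.0, -1.0}, prior, 1.5);
        System.out.println(Arrays.toString(mask) + " prior=" + Arrays.toString(prior));
    }
}
```

Sparsemax can return exact zeros, so each step's selection is genuinely sparse and easy to inspect. A softmax, by contrast, would give every feature a nonzero weight.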
    • Field Summary

      • Fields inherited from class ai.djl.nn.AbstractBlock

        children, parameters
      • Fields inherited from class ai.djl.nn.AbstractBaseBlock

        inputNames, inputShapes, outputDataTypes, version
    • Method Summary

      Modifier and Type Method Description
      static TabNet.Builder builder()
      Creates a builder to build a TabNet.
      static ai.djl.nn.Block featureTransformer(java.util.List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
      Creates a featureTransformer Block.
      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
      ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
      Creates an FC-BN-GLU block used in TabNet.
      protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
      static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
      Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
      static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
      Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
      static ai.djl.nn.Block tabNetGLUBlock(int units)
      Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
      • Methods inherited from class ai.djl.nn.AbstractBlock

        addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
      • Methods inherited from class ai.djl.nn.AbstractBaseBlock

        beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
      • Methods inherited from interface ai.djl.nn.Block

        forward, freezeParameters, getOutputShapes
    • Method Detail

      • tabNetGLU

        public static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array,
                                                       int units)
        Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
        Parameters:
        array - the input NDArray
        units - half the number of input features, i.e. the number of features in the result
        Returns:
        the NDArray after applying the tabNetGLU function
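A gated linear unit (GLU) keeps one half of the features and gates it with the sigmoid of the other half. An input with 2 * units features therefore yields units features. The plain-Java sketch below shows this on a [batch, features] array. It assumes the first units columns are the values and the remaining units columns are the gate. The class GluSketch is hypothetical and not part of DJL.

```java
/** Hypothetical plain-Java sketch of a GLU over the feature axis (not DJL's implementation). */
final class GluSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Each row must have 2 * units features; returns out[b][j] = x[b][j] * sigmoid(x[b][units + j]). */
    static double[][] glu(double[][] x, int units) {
        double[][] out = new double[x.length][units];
        for (int b = 0; b < x.length; b++) {
            if (x[b].length != 2 * units) {
                throw new IllegalArgumentException("expected 2 * units features per row");
            }
            for (int j = 0; j < units; j++) {
                out[b][j] = x[b][j] * sigmoid(x[b][units + j]); // value half gated by the other half
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] y = glu(new double[][] {{3.0, -2.0, 0.0, 100.0}}, 2);
        System.out.println(y[0][0] + " " + y[0][1]);
    }
}
```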
      • tabNetGLU

        public static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays,
                                                      int units)
        Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
        Parameters:
        arrays - the input singleton NDList
        units - half the number of input features, i.e. the number of features in the result
        Returns:
        the singleton NDList after applying the tabNetGLU function
      • tabNetGLUBlock

        public static ai.djl.nn.Block tabNetGLUBlock(int units)
        Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
        Parameters:
        units - half the number of input features, i.e. the number of output features
        Returns:
        a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function
      • gluBlock

        public static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock,
                                               int outDim,
                                               int virtualBatchSize,
                                               float batchNormMomentum)
        Creates an FC-BN-GLU block used in TabNet. Because GLU halves the feature dimension, the fully connected (FC) layer doubles the dimension of the features fed into the GLU, so the block outputs outDim features.
        Parameters:
        sharedBlock - the shared fully connected layer
        outDim - the output feature dimension
        virtualBatchSize - the virtual batch size for ghost batch norm
        batchNormMomentum - the momentum used for the ghost batch norm layer
        Returns:
        an FC-BN-GLU block
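The FC-BN-GLU pipeline has three stages. An FC layer maps to 2 * outDim features. Ghost batch norm then normalizes each feature within consecutive virtual batches of virtualBatchSize rows. Finally, a GLU maps back to outDim features. The plain-Java sketch below illustrates that flow under simplifying assumptions:

- statistics are training-mode only;
- there is no learnable scale or shift and no running statistics, so batchNormMomentum is not modelled;
- the FC layer has no bias.

The class and method names are hypothetical and do not reflect DJL's GhostBatchNorm or Linear implementations.

```java
/** Hypothetical sketch of FC -> ghost batch norm -> GLU (training-mode statistics only). */
final class FcBnGluSketch {

    /** Fully connected layer without bias: [batch, in] x [in, 2 * outDim]. */
    static double[][] fc(double[][] x, double[][] w) {
        int batch = x.length, inDim = w.length, outDim = w[0].length;
        double[][] y = new double[batch][outDim];
        for (int b = 0; b < batch; b++) {
            for (int i = 0; i < inDim; i++) {
                for (int o = 0; o < outDim; o++) {
                    y[b][o] += x[b][i] * w[i][o];
                }
            }
        }
        return y;
    }

    /** Normalizes each feature separately within consecutive chunks of virtualBatchSize rows. */
    static double[][] ghostBatchNorm(double[][] x, int virtualBatchSize, double eps) {
        int batch = x.length, features = x[0].length;
        double[][] y = new double[batch][features];
        for (int start = 0; start < batch; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch); // last chunk may be smaller
            int n = end - start;
            for (int f = 0; f < features; f++) {
                double mean = 0.0;
                for (int b = start; b < end; b++) {
                    mean += x[b][f];
                }
                mean /= n;
                double var = 0.0;
                for (int b = start; b < end; b++) {
                    var += (x[b][f] - mean) * (x[b][f] - mean);
                }
                var /= n;
                for (int b = start; b < end; b++) {
                    y[b][f] = (x[b][f] - mean) / Math.sqrt(var + eps);
                }
            }
        }
        return y;
    }

    /** GLU: first outDim features gated by the sigmoid of the last outDim features. */
    static double[][] glu(double[][] x, int outDim) {
        double[][] y = new double[x.length][outDim];
        for (int b = 0; b < x.length; b++) {
            for (int j = 0; j < outDim; j++) {
                y[b][j] = x[b][j] / (1.0 + Math.exp(-x[b][outDim + j]));
            }
        }
        return y;
    }

    static double[][] forward(double[][] x, double[][] w, int outDim, int virtualBatchSize) {
        if (w[0].length != 2 * outDim) {
            throw new IllegalArgumentException("FC layer must produce 2 * outDim features");
        }
        return glu(ghostBatchNorm(fc(x, w), virtualBatchSize, 1e-5), outDim);
    }

    public static void main(String[] args) {
        double[][] x = {{1, 0}, {0, 1}, {2, 0}, {0, 2}};
        double[][] w = {{1, 0}, {0, 1}}; // 2 input features -> 2 * outDim = 2 features
        double[][] y = forward(x, w, 1, 2);
        System.out.println(y.length + " x " + y[0].length);
    }
}
```

Normalizing within small virtual batches, rather than over the whole (often large) tabular batch, is the motivation the TabNet paper gives for ghost batch norm.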
      • featureTransformer

        public static ai.djl.nn.Block featureTransformer(java.util.List<ai.djl.nn.Block> sharedBlocks,
                                                         int outDim,
                                                         int numIndependent,
                                                         int virtualBatchSize,
                                                         float batchNormMomentum)
        Creates a featureTransformer Block. The feature transformer is where all the selected features are processed to generate the final output.
        Parameters:
        sharedBlocks - the shared blocks of the feature transformer
        outDim - the output dimension of the feature transformer
        numIndependent - the number of independent blocks of the feature transformer
        virtualBatchSize - the virtual batch size for ghost batch norm
        batchNormMomentum - the momentum for the batch norm layer
        Returns:
        a feature transformer
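The TabNet paper chains the GLU blocks of a feature transformer with residual connections scaled by sqrt(0.5) to keep the variance stable. The first block maps the input to outDim features. Each later block's output is added to its input and the sum is multiplied by sqrt(0.5). Below is a minimal plain-Java sketch of that chaining, assuming the paper's layout (shared blocks first, then the independent ones). Blocks are modelled as hypothetical UnaryOperator<double[]> functions, and the sketch is not DJL's featureTransformer code.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of residual chaining in a TabNet feature transformer. */
final class FeatureTransformerSketch {

    private static final double SCALE = Math.sqrt(0.5);

    static UnaryOperator<double[]> chain(List<UnaryOperator<double[]>> shared,
                                         List<UnaryOperator<double[]>> independent) {
        List<UnaryOperator<double[]>> blocks = new ArrayList<>(shared);
        blocks.addAll(independent);
        if (blocks.isEmpty()) {
            throw new IllegalArgumentException("at least one block is required");
        }
        return x -> {
            double[] h = blocks.get(0).apply(x); // first block changes the width: no residual
            for (int i = 1; i < blocks.size(); i++) {
                double[] f = blocks.get(i).apply(h); // later blocks must preserve the width
                double[] next = new double[h.length];
                for (int j = 0; j < h.length; j++) {
                    next[j] = (h[j] + f[j]) * SCALE; // residual add, then scale by sqrt(0.5)
                }
                h = next;
            }
            return h;
        };
    }

    public static void main(String[] args) {
        UnaryOperator<double[]> firstSum = x -> new double[] {x[0] + x[1] + x[2]}; // 3 -> 1 features
        UnaryOperator<double[]> identity = x -> x.clone();
        UnaryOperator<double[]> ft = chain(List.of(firstSum, identity), List.of(identity));
        System.out.println(ft.apply(new double[] {1, 2, 3})[0]);
    }
}
```

Sharing the first blocks lets every decision step reuse common feature processing. The independent blocks give each step its own parameters.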
      • forwardInternal

        protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore,
                                                        ai.djl.ndarray.NDList inputs,
                                                        boolean training,
                                                        ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        Specified by:
        forwardInternal in class ai.djl.nn.AbstractBaseBlock
      • getOutputShapes

        public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      • initializeChildBlocks

        protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager,
                                             ai.djl.ndarray.types.DataType dataType,
                                             ai.djl.ndarray.types.Shape... inputShapes)
        Overrides:
        initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock
      • builder

        public static TabNet.Builder builder()
        Creates a builder to build a TabNet.
        Returns:
        a new builder