Class TabNet

All Implemented Interfaces: ai.djl.nn.Block

public final class TabNet extends ai.djl.nn.AbstractBlock
TabNet contains a generic implementation of TabNet, adapted from an implementation originally authored by Samrat Thapa.

TabNet is a neural architecture for tabular data developed by the research team at Google Cloud AI. It achieved state-of-the-art results on several datasets for both regression and classification problems. Another desirable feature of TabNet is interpretability: unlike most deep learning models, whose neural networks act like black boxes, TabNet lets us interpret which features the model selects.

See the original TabNet paper for more information about TabNet.

  • Nested Class Summary

    Nested Classes
    Modifier and Type
    static final class TabNet.AttentionTransformer
    AttentionTransformer is where the TabNet model learns the relationships between relevant features and decides which features to pass on to the feature transformer of the current decision step.
    static class TabNet.Builder
    The Builder to construct a TabNet object.
    static final class TabNet.DecisionStep
    DecisionStep combines the feature transformer and the attention transformer.
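The TabNet paper describes the attention transformer as producing a sparse feature mask with sparsemax, scaled by a "prior" that discourages reusing features already selected in earlier decision steps. The plain-Java sketch below illustrates that idea for a single sample. It is not DJL's code: whether TabNet.AttentionTransformer uses exactly this formulation is an assumption, and the names AttentionMaskSketch, attentiveMask and gamma (the paper's relaxation factor) are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL's implementation): sparse feature masks in the
 * style of the TabNet paper, mask = sparsemax(prior * scores).
 */
final class AttentionMaskSketch {

    /** Sparsemax: Euclidean projection of z onto the probability simplex. */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; walked from the largest value down
        int n = sorted.length;
        double cumSum = 0.0;
        double tau = 0.0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k]; // k-th largest value
            cumSum += zk;
            // k is in the support while 1 + k * z_(k) > sum of the k largest values
            if (1.0 + k * zk > cumSum) {
                tau = (cumSum - 1.0) / k;
            }
        }
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(0.0, z[i] - tau); // entries below the threshold become exactly 0
        }
        return p;
    }

    /**
     * One attentive step: the mask is sparsemax of the prior-scaled scores,
     * then the prior is updated in place as prior *= (gamma - mask), so
     * features used heavily now get less weight in later steps.
     */
    static double[] attentiveMask(double[] scores, double[] prior, double gamma) {
        double[] scaled = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            scaled[i] = prior[i] * scores[i];
        }
        double[] mask = sparsemax(scaled);
        for (int i = 0; i < prior.length; i++) {
            prior[i] *= gamma - mask[i];
        }
        return mask;
    }
}
```

For scores {1.0, 0.0, -1.0} with a uniform prior of 1.0 and gamma = 1.5, the mask is {1.0, 0.0, 0.0} (only the first feature is selected) and the prior becomes {0.5, 1.5, 1.5}, which is what makes the selected features inspectable.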
  • Field Summary

    Fields inherited from class ai.djl.nn.AbstractBlock

    children, parameters

    Fields inherited from class ai.djl.nn.AbstractBaseBlock

    inputNames, inputShapes, outputDataTypes, version
  • Method Summary

    Modifier and Type
    static TabNet.Builder
    builder()
    Creates a builder to build a TabNet.
    static ai.djl.nn.Block
    featureTransformer(List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
    Creates a featureTransformer Block.
    protected ai.djl.ndarray.NDList
    forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
    ai.djl.ndarray.types.Shape[]
    getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    static ai.djl.nn.Block
    gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
    Creates an FC-BN-GLU block used in TabNet.
    protected void
    initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
    static ai.djl.ndarray.NDArray
    tabNetGLU(ai.djl.ndarray.NDArray array, int units)
    Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
    static ai.djl.ndarray.NDList
    tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
    Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
    static ai.djl.nn.Block
    tabNetGLUBlock(int units)
    Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.

    Methods inherited from class ai.djl.nn.AbstractBlock

    addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters

    Methods inherited from class ai.djl.nn.AbstractBaseBlock

    beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface ai.djl.nn.Block

    forward, freezeParameters, freezeParameters, getOutputShapes
  • Method Details

    • tabNetGLU

      public static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
      Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
      Parameters:
      array - the input NDArray
      units - half the number of input features, i.e. the number of features in the result
      Returns:
      the NDArray after applying the tabNetGLU function
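As a rough illustration of what tabNetGLU computes, here is a plain-Java sketch on a [batch, 2 * units] matrix. It assumes the standard GLU split, in which the first units columns are the values and the remaining units columns, passed through a sigmoid, act as gates; the exact split order inside tabNetGLU is an assumption, and GluSketch is a hypothetical name.

```java
/** Hypothetical sketch of a GLU over the feature axis (not DJL's code). */
final class GluSketch {

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /**
     * Maps each row of width 2 * units to a row of width units:
     * out[j] = x[j] * sigmoid(x[j + units]).
     */
    static double[][] glu(double[][] x, int units) {
        double[][] out = new double[x.length][units];
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != 2 * units) {
                throw new IllegalArgumentException("each row needs exactly 2 * units features");
            }
            for (int j = 0; j < units; j++) {
                out[i][j] = x[i][j] * sigmoid(x[i][j + units]); // value * gate
            }
        }
        return out;
    }
}
```

With a zero gate half, sigmoid(0) = 0.5, so the row {2.0, 4.0, 0.0, 0.0} with units = 2 becomes {1.0, 2.0}: the output has units features, half as many as the input.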
    • tabNetGLU

      public static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
      Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
      Parameters:
      arrays - the input singleton NDList
      units - half the number of input features, i.e. the number of features in the result
      Returns:
      the singleton NDList after applying the tabNetGLU function
    • tabNetGLUBlock

      public static ai.djl.nn.Block tabNetGLUBlock(int units)
      Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
      Parameters:
      units - half the number of input features, i.e. the number of output features
      Returns:
      a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function
    • gluBlock

      public static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
      Creates an FC-BN-GLU block used in TabNet. To apply GLU, a fully connected layer first produces twice as many features (2 × outDim) as the block outputs; the GLU then halves them back to outDim.
      Parameters:
      sharedBlock - the shared fully connected layer
      outDim - the output feature dimension
      virtualBatchSize - the virtual batch size for ghost batch normalization
      batchNormMomentum - the momentum used for the ghost batch normalization layer
      Returns:
      an FC-BN-GLU block
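The FC-BN-GLU pipeline can be sketched in plain Java as follows. This is an illustration under stated assumptions, not the gluBlock implementation: the fully connected layer has no bias, ghost batch normalization is shown in training mode only (each virtual batch of at most virtualBatchSize rows is normalized with its own mean and variance, while the learnable scale/shift and the running statistics that batchNormMomentum would update are omitted), and the class and method names are hypothetical.

```java
import java.util.function.UnaryOperator;

/** Hypothetical FC -> ghost BN -> GLU sketch (not DJL's gluBlock). */
final class FcBnGluSketch {
    static final double EPS = 1e-5;

    /** Fully connected layer without bias: [batch, in] x [in, out] -> [batch, out]. */
    static double[][] linear(double[][] x, double[][] w) {
        int out = w[0].length;
        double[][] y = new double[x.length][out];
        for (int i = 0; i < x.length; i++) {
            for (int k = 0; k < w.length; k++) {
                for (int j = 0; j < out; j++) {
                    y[i][j] += x[i][k] * w[k][j];
                }
            }
        }
        return y;
    }

    /** Training-mode ghost batch norm: each chunk of virtualBatchSize rows uses its own statistics. */
    static double[][] ghostBatchNorm(double[][] x, int virtualBatchSize) {
        int features = x[0].length;
        double[][] y = new double[x.length][features];
        for (int start = 0; start < x.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, x.length);
            int n = end - start;
            for (int j = 0; j < features; j++) {
                double mean = 0.0;
                for (int i = start; i < end; i++) {
                    mean += x[i][j];
                }
                mean /= n;
                double var = 0.0;
                for (int i = start; i < end; i++) {
                    var += (x[i][j] - mean) * (x[i][j] - mean);
                }
                var /= n; // biased variance, as batch norm uses when normalizing in training
                double invStd = 1.0 / Math.sqrt(var + EPS);
                for (int i = start; i < end; i++) {
                    y[i][j] = (x[i][j] - mean) * invStd;
                }
            }
        }
        return y;
    }

    /** GLU: out[j] = x[j] * sigmoid(x[j + units]), written as x / (1 + exp(-gate)). */
    static double[][] glu(double[][] x, int units) {
        double[][] out = new double[x.length][units];
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < units; j++) {
                out[i][j] = x[i][j] / (1.0 + Math.exp(-x[i][j + units]));
            }
        }
        return out;
    }

    /** FC projects to 2 * outDim features, ghost BN normalizes them, GLU maps back to outDim. */
    static UnaryOperator<double[][]> fcBnGlu(double[][] w, int outDim, int virtualBatchSize) {
        if (w[0].length != 2 * outDim) {
            throw new IllegalArgumentException("the FC layer must output 2 * outDim features");
        }
        return x -> glu(ghostBatchNorm(linear(x, w), virtualBatchSize), outDim);
    }
}
```

Normalizing the column {1, 3, 10, 30} with virtualBatchSize = 2 gives roughly {-1, 1, -1, 1}, because each pair is normalized on its own; a single full-batch normalization would not produce that pattern.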
    • featureTransformer

      public static ai.djl.nn.Block featureTransformer(List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
      Creates a featureTransformer Block. The feature transformer is where all the selected features are processed to generate the final output.
      Parameters:
      sharedBlocks - the shared blocks of the feature transformer
      outDim - the output dimension of the feature transformer
      numIndependent - the number of independent blocks of the feature transformer
      virtualBatchSize - the virtual batch size for ghost batch normalization
      batchNormMomentum - the momentum for the batch normalization layer
      Returns:
      a feature transformer
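The TabNet paper describes the feature transformer as a stack of FC-BN-GLU blocks, the shared ones first and the step-specific (independent) ones after, with each block after the first wrapped in a residual connection scaled by sqrt(0.5) so the variance does not change dramatically through the network. The sketch below shows only that wiring, with each block modeled as a UnaryOperator over [batch, features] arrays. It assumes featureTransformer follows the paper's layout; FeatureTransformerSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical wiring sketch of a TabNet-style feature transformer (not DJL's code). */
final class FeatureTransformerSketch {
    static final double SCALE = Math.sqrt(0.5);

    /**
     * Chains the shared blocks, then the independent blocks. The first block
     * may change the feature width, so it has no residual; every later block
     * computes (h + block(h)) * sqrt(0.5).
     */
    static UnaryOperator<double[][]> featureTransformer(
            List<UnaryOperator<double[][]>> shared, List<UnaryOperator<double[][]>> independent) {
        List<UnaryOperator<double[][]>> blocks = new ArrayList<>(shared);
        blocks.addAll(independent);
        return x -> {
            double[][] h = x;
            for (int b = 0; b < blocks.size(); b++) {
                double[][] out = blocks.get(b).apply(h);
                if (b > 0) {
                    // Scaled residual; a new array is allocated so block outputs are not mutated.
                    double[][] res = new double[out.length][];
                    for (int i = 0; i < out.length; i++) {
                        res[i] = new double[out[i].length];
                        for (int j = 0; j < out[i].length; j++) {
                            res[i][j] = (h[i][j] + out[i][j]) * SCALE;
                        }
                    }
                    out = res;
                }
                h = out;
            }
            return h;
        };
    }
}
```

With two identity blocks (one shared, one independent), an input row {1.0, 2.0} comes out as roughly {1.414, 2.828}: the second block adds its input back and the sum is scaled by sqrt(0.5).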
    • forwardInternal

      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class ai.djl.nn.AbstractBaseBlock
    • getOutputShapes

      public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    • initializeChildBlocks

      protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
      Overrides:
      initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock
    • builder

      public static TabNet.Builder builder()
      Creates a builder to build a TabNet.
      Returns:
      a new builder