Package ai.djl.basicmodelzoo.tabular
Class TabNet
java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.basicmodelzoo.tabular.TabNet
- All Implemented Interfaces:
ai.djl.nn.Block
public final class TabNet
extends ai.djl.nn.AbstractBlock
TabNet contains a generic implementation of TabNet adapted from https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279 (original author: Samrat Thapa).
TabNet is a neural architecture for tabular datasets developed by the research team at Google Cloud AI. It achieved state-of-the-art results on several datasets for both regression and classification problems. Another desirable feature of TabNet is interpretability: unlike most deep learning models, whose networks act like black boxes, TabNet lets us interpret which features the model selects.
See https://arxiv.org/pdf/1908.07442.pdf for more information about TabNet.
Nested Class Summary
- static final class TabNet.AttentionTransformer
  The AttentionTransformer is where the TabNet model learns the relationship between relevant features and decides which features to pass on to the feature transformer of the current decision step.
- static class TabNet.Builder
  The Builder to construct a TabNet object.
- static final class TabNet.DecisionStep
  A DecisionStep combines the featureTransformer and the attentionTransformer.
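The feature selection that the AttentionTransformer performs can be illustrated with the mask update described in the TabNet paper (arXiv:1908.07442): the attention logits are scaled by a prior P, sparsemax turns them into a sparse mask M whose entries sum to 1, and the prior is then updated as P[i] = P[i-1] · (γ − M[i]), so features used heavily in earlier decision steps get less weight later. This is a minimal plain-Java sketch for a single sample under those assumptions, not DJL's implementation. The class and method names and the γ value are hypothetical, and the FC and ghost batch-norm layers that produce the logits are omitted.

```java
import java.util.Arrays;

/** Illustrative sketch of TabNet's attentive feature selection; not DJL's implementation. */
final class AttentionMaskSketch {

    /** Sparsemax: projects z onto the probability simplex, so many outputs are exactly 0. */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; read from the end for descending order
        int n = sorted.length;
        double cumSum = 0.0;
        double supportSum = 0.0;
        int support = 0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k]; // k-th largest value
            cumSum += zk;
            if (1.0 + k * zk > cumSum) { // zk stays in the support
                support = k;
                supportSum = cumSum;
            }
        }
        double tau = (supportSum - 1.0) / support; // threshold subtracted from every entry
        double[] out = new double[n];
        for (int j = 0; j < n; j++) {
            out[j] = Math.max(z[j] - tau, 0.0);
        }
        return out;
    }

    /**
     * One attentive step as described in the paper: M = sparsemax(prior * logits), then
     * prior <- prior * (gamma - M). The logits stand in for the FC + ghost BN output.
     * The prior array is updated in place so it can be passed to the next step.
     */
    static double[] maskStep(double[] logits, double[] prior, double gamma) {
        double[] scaled = new double[logits.length];
        for (int j = 0; j < logits.length; j++) {
            scaled[j] = prior[j] * logits[j];
        }
        double[] mask = sparsemax(scaled);
        for (int j = 0; j < prior.length; j++) {
            prior[j] *= gamma - mask[j];
        }
        return mask;
    }

    public static void main(String[] args) {
        double[] prior = {1.0, 1.0, 1.0, 1.0}; // P[0]: every feature starts fully available
        double[] logits = {2.0, 1.0, 0.1, -1.0}; // same logits at both steps to isolate the prior
        double gamma = 1.5; // relaxation parameter (hypothetical value)
        for (int step = 1; step <= 2; step++) {
            double[] mask = maskStep(logits, prior, gamma);
            System.out.println("step " + step + ": " + Arrays.toString(mask));
        }
    }
}
```

Running main prints [1.0, 0.0, 0.0, 0.0] for step 1 and [0.25, 0.75, 0.0, 0.0] for step 2: with identical logits, feature 0 loses weight at step 2 because its prior shrank to γ − 1 = 0.5 after it was fully selected at step 1. In the paper, the mask multiplies the input features elementwise before they reach the feature transformer of the same step.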
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Method Summary
- static TabNet.Builder builder()
  Creates a builder to build a TabNet.
- static ai.djl.nn.Block featureTransformer(List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
  Creates a featureTransformer Block.
- protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
  Creates a FC-BN-GLU block used in TabNet.
- protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
- static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
  Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
- static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
  Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
- static ai.djl.nn.Block tabNetGLUBlock(int units)
  Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes

Method Details

tabNetGLU
public static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
- Parameters:
array - the input NDArray
units - half the number of input features (the result has this many features)
- Returns:
the NDArray after applying the tabNetGLU function
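The tabNetGLU activation can be sketched on plain Java arrays. The sketch assumes the usual gated-linear-unit convention: each input row has 2 · units features, the first units values are multiplied elementwise by the sigmoid of the last units values, and the result has units features. Which half DJL gates is not stated above, so treat that ordering as an assumption; GluSketch is a hypothetical name, not part of DJL.

```java
import java.util.Arrays;

/** Plain-Java sketch of a GLU over the feature axis; illustrative, not DJL's implementation. */
final class GluSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Maps a [batch, 2 * units] input to [batch, units]: out[r][j] = a[j] * sigmoid(b[j]),
     * where a holds the first units features of row r and b holds the last units features.
     */
    static double[][] glu(double[][] input, int units) {
        double[][] out = new double[input.length][units];
        for (int r = 0; r < input.length; r++) {
            if (input[r].length != 2 * units) {
                throw new IllegalArgumentException("each row needs exactly 2 * units features");
            }
            for (int j = 0; j < units; j++) {
                out[r][j] = input[r][j] * sigmoid(input[r][j + units]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{2.0, -4.0, 0.0, 100.0}}; // one row, units = 2
        // 2.0 * sigmoid(0.0) = 1.0; -4.0 * sigmoid(100.0) is -4.0 to double precision
        System.out.println(Arrays.toString(glu(x, 2)[0])); // prints [1.0, -4.0]
    }
}
```

Because the gate is a sigmoid in (0, 1), each output is a damped copy of its paired input: a large gate logit passes the value through almost unchanged, and a very negative one suppresses it.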

tabNetGLU
public static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
- Parameters:
arrays - the input singleton NDList
units - half the number of input features (the result has this many features)
- Returns:
the singleton NDList after applying the tabNetGLU function

tabNetGLUBlock
public static ai.djl.nn.Block tabNetGLUBlock(int units)
Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
- Parameters:
units - half the number of input features
- Returns:
a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function

gluBlock
public static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
Creates a FC-BN-GLU block used in TabNet. Because the GLU halves the feature dimension, a fully connected layer first doubles the dimension of the features fed into the GLU.
- Parameters:
sharedBlock - the shared fully connected layer
outDim - the output feature dimension
virtualBatchSize - the virtual batch size for ghost batch norm
batchNormMomentum - the momentum used for the ghost batch norm layer
- Returns:
a FC-BN-GLU block
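The FC-BN-GLU composition that gluBlock describes, together with the ghost batch norm controlled by virtualBatchSize and batchNormMomentum, can be sketched on plain arrays. Ghost batch norm splits a batch into chunks of virtualBatchSize rows and normalizes each chunk with its own statistics. The sketch assumes the momentum weights the previous running value (running = momentum · running + (1 − momentum) · chunk statistic, the MXNet-style convention), omits the learnable scale and shift as well as inference mode, and uses hypothetical names; it is not DJL's implementation.

```java
import java.util.Arrays;

/** Sketch of an FC -> ghost batch norm -> GLU block on plain arrays (training mode only). */
final class FcBnGluSketch {

    /** Bias-free fully connected layer: [batch, in] x [in, out] -> [batch, out]. */
    static double[][] linear(double[][] x, double[][] w) {
        double[][] y = new double[x.length][w[0].length];
        for (int r = 0; r < x.length; r++) {
            for (int k = 0; k < w.length; k++) {
                for (int c = 0; c < w[0].length; c++) {
                    y[r][c] += x[r][k] * w[k][c];
                }
            }
        }
        return y;
    }

    /**
     * Ghost batch norm: each chunk of virtualBatchSize rows is normalized with its own
     * per-feature mean and variance. Running statistics are updated once per chunk as
     * running = momentum * running + (1 - momentum) * chunkStat. Scale/shift are omitted.
     */
    static double[][] ghostBatchNorm(double[][] x, int virtualBatchSize, double momentum,
                                     double[] runningMean, double[] runningVar) {
        final double eps = 1e-5;
        int features = x[0].length;
        double[][] y = new double[x.length][features];
        for (int start = 0; start < x.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, x.length);
            int n = end - start;
            for (int c = 0; c < features; c++) {
                double mean = 0.0;
                for (int r = start; r < end; r++) {
                    mean += x[r][c];
                }
                mean /= n;
                double var = 0.0;
                for (int r = start; r < end; r++) {
                    var += (x[r][c] - mean) * (x[r][c] - mean);
                }
                var /= n;
                for (int r = start; r < end; r++) {
                    y[r][c] = (x[r][c] - mean) / Math.sqrt(var + eps);
                }
                runningMean[c] = momentum * runningMean[c] + (1 - momentum) * mean;
                runningVar[c] = momentum * runningVar[c] + (1 - momentum) * var;
            }
        }
        return y;
    }

    /** GLU over the feature axis: the first half gated by the sigmoid of the second half. */
    static double[][] glu(double[][] x, int units) {
        double[][] y = new double[x.length][units];
        for (int r = 0; r < x.length; r++) {
            for (int j = 0; j < units; j++) {
                y[r][j] = x[r][j] / (1.0 + Math.exp(-x[r][j + units]));
            }
        }
        return y;
    }

    /** FC up to 2 * outDim features, ghost BN, then GLU back down to outDim features. */
    static double[][] fcBnGlu(double[][] x, double[][] w, int outDim, int virtualBatchSize,
                              double momentum, double[] runningMean, double[] runningVar) {
        if (w[0].length != 2 * outDim) {
            throw new IllegalArgumentException("w must have 2 * outDim columns");
        }
        double[][] h = linear(x, w);
        h = ghostBatchNorm(h, virtualBatchSize, momentum, runningMean, runningVar);
        return glu(h, outDim);
    }

    public static void main(String[] args) {
        double[][] x = {{1.0, 0.0}, {3.0, 1.0}, {10.0, 2.0}, {30.0, 5.0}}; // batch of 4
        double[][] w = {{1.0, 0.0}, {0.0, 1.0}}; // identity: in = 2, 2 * outDim = 2
        double[] runningMean = new double[2];
        double[] runningVar = new double[2];
        double[][] out = fcBnGlu(x, w, 1, 2, 0.9, runningMean, runningVar);
        System.out.println(out.length + " x " + out[0].length); // prints 4 x 1
        System.out.println(Arrays.toString(runningMean));
    }
}
```

With outDim = 1 the FC layer produces 2 · outDim = 2 features per row, and the GLU halves that back to one, so main prints 4 x 1. Normalizing per chunk of 2 rows means rows 1 to 2 and rows 3 to 4 are each rescaled with their own statistics, even though their raw magnitudes differ by an order of magnitude.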

featureTransformer
public static ai.djl.nn.Block featureTransformer(List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
Creates a featureTransformer Block. The feature transformer is where all the selected features are processed to generate the final output.
- Parameters:
sharedBlocks - the shared blocks of the feature transformer
outDim - the output dimension of the feature transformer
numIndependent - the number of independent blocks of the feature transformer
virtualBatchSize - the virtual batch size for ghost batch norm
batchNormMomentum - the momentum for the batch norm layer
- Returns:
a feature transformer
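How featureTransformer might chain its sharedBlocks and its numIndependent independent blocks can be sketched from the TabNet paper's description, in which consecutive FC-BN-GLU layers are joined by residual connections scaled by √0.5. Whether DJL applies exactly this scheme is an assumption of this sketch; the layers are hypothetical UnaryOperator stand-ins for the blocks that gluBlock returns.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Sketch of how a feature transformer chains its layers; illustrative, not DJL's code. */
final class FeatureTransformerSketch {

    private static final double SCALE = Math.sqrt(0.5);

    /**
     * Runs the shared layers, then the step-specific (independent) layers. The first layer
     * may change the feature dimension, so it has no residual; every later layer is applied
     * as h = (h + layer(h)) * sqrt(0.5), which the paper uses to keep the variance stable.
     */
    static double[][] apply(List<UnaryOperator<double[][]>> shared,
                            List<UnaryOperator<double[][]>> independent,
                            double[][] x) {
        List<UnaryOperator<double[][]>> layers = new ArrayList<>(shared);
        layers.addAll(independent);
        double[][] h = layers.get(0).apply(x);
        for (int i = 1; i < layers.size(); i++) {
            double[][] f = layers.get(i).apply(h);
            double[][] next = new double[h.length][h[0].length];
            for (int r = 0; r < h.length; r++) {
                for (int c = 0; c < h[0].length; c++) {
                    next[r][c] = (h[r][c] + f[r][c]) * SCALE;
                }
            }
            h = next;
        }
        return h;
    }

    public static void main(String[] args) {
        // Hypothetical toy layers: the first maps 2 features to 1, the others double their input.
        UnaryOperator<double[][]> sumFeatures = x -> {
            double[][] y = new double[x.length][1];
            for (int r = 0; r < x.length; r++) {
                y[r][0] = x[r][0] + x[r][1];
            }
            return y;
        };
        UnaryOperator<double[][]> doubler = x -> {
            double[][] y = new double[x.length][x[0].length];
            for (int r = 0; r < x.length; r++) {
                for (int c = 0; c < x[0].length; c++) {
                    y[r][c] = 2.0 * x[r][c];
                }
            }
            return y;
        };
        List<UnaryOperator<double[][]>> shared = List.of(sumFeatures, doubler);
        List<UnaryOperator<double[][]>> independent = List.of(doubler, doubler);
        double[][] out = apply(shared, independent, new double[][] {{1.0, 1.0}});
        System.out.println(out[0][0]); // 2 * (3 * sqrt(0.5))^3, about 19.09
    }
}
```

Taking the shared list as an argument mirrors the sharedBlocks parameter: in the paper, the shared layers are reused across all decision steps, while each step gets its own independent layers.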

forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
- Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock

getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)

initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
- Overrides:
initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock

builder
public static TabNet.Builder builder()
Creates a builder to build a TabNet.
- Returns:
a new builder