Package ai.djl.basicmodelzoo.tabular
Class TabNet
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.basicmodelzoo.tabular.TabNet
- All Implemented Interfaces: ai.djl.nn.Block
public final class TabNet extends ai.djl.nn.AbstractBlock
The TabNet class contains a generic implementation of TabNet adapted from https://towardsdatascience.com/implementing-tabnet-in-pytorch-fc977c383279 (original author: Samrat Thapa).
TabNet is a neural architecture for tabular data developed by the research team at Google Cloud AI. It achieved state-of-the-art results on several datasets for both regression and classification problems. Another desirable feature of TabNet is interpretability: unlike most deep learning models, where the neural network acts as a black box, TabNet lets us interpret which features the model selects.
See https://arxiv.org/pdf/1908.07442.pdf for more information about TabNet.
-
Nested Class Summary
- static class TabNet.AttentionTransformer
  AttentionTransformer is where the TabNet model learns the relationships between relevant features and decides which features to pass on to the feature transformer of the current decision step.
- static class TabNet.Builder
  The Builder to construct a TabNet object.
- static class TabNet.DecisionStep
  DecisionStep combines a featureTransformer and an attentionTransformer.
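In the TabNet paper, the attentive transformer turns its input into a sparse feature mask with sparsemax, after scaling it by a prior that records how heavily each feature has already been used; after each step the prior is updated as P ← P · (γ − M). This page does not say whether DJL's AttentionTransformer follows exactly that formulation, so the following is only a minimal pure-Java sketch under that assumption. The class SparsemaxSketch, its methods, and the relaxation factor `gamma` are hypothetical illustrations, not DJL API, and the input `a` is assumed to be the already-computed output of the FC and batch-norm layers that precede the masking.

```java
import java.util.Arrays;

/** Hypothetical sketch of the attentive-transformer masking step described in the TabNet paper. */
public final class SparsemaxSketch {

    /** Sparsemax: Euclidean projection of z onto the probability simplex (zeros out small entries). */
    public static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; traversed from the end to get descending order
        int n = sorted.length;
        double cumSum = 0.0;
        double tau = 0.0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k];
            cumSum += zk;
            // The k-th largest value stays in the support while 1 + k * z_(k) > sum of the top-k values.
            if (1.0 + k * zk > cumSum) {
                tau = (cumSum - 1.0) / k;
            }
        }
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    /**
     * One masking step: M = sparsemax(prior * a), then prior <- prior * (gamma - M).
     * Mutates prior in place so it can be carried to the next decision step.
     */
    public static double[] maskStep(double[] a, double[] prior, double gamma) {
        double[] scaled = new double[a.length];
        for (int j = 0; j < a.length; j++) {
            scaled[j] = prior[j] * a[j];
        }
        double[] mask = sparsemax(scaled);
        for (int j = 0; j < a.length; j++) {
            prior[j] *= gamma - mask[j];
        }
        return mask;
    }

    public static void main(String[] args) {
        double[] prior = {1.0, 1.0, 1.0}; // no feature used yet
        double[] mask = maskStep(new double[] {1.0, 0.0, -1.0}, prior, 1.5);
        System.out.println(Arrays.toString(mask));  // [1.0, 0.0, 0.0]
        System.out.println(Arrays.toString(prior)); // [0.5, 1.5, 1.5]
    }
}
```

The mask selects only the first feature here, and its prior drops to 0.5, which makes that feature less likely to be chosen again at the next decision step.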
-
Method Summary
- static TabNet.Builder builder()
  Creates a builder to build a TabNet.
- static ai.djl.nn.Block featureTransformer(java.util.List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
  Creates a featureTransformer Block.
- protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
- static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
  Creates an FC-BN-GLU block used in TabNet.
- protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
- static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
  Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
- static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
  Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
- static ai.djl.nn.Block tabNetGLUBlock(int units)
  Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
-
Method Detail
-
tabNetGLU
public static ai.djl.ndarray.NDArray tabNetGLU(ai.djl.ndarray.NDArray array, int units)
Applies the tabNetGLU activation (used mainly in TabNet) to the input NDArray.
- Parameters:
  array - the input NDArray
  units - half the number of input features; the result has units features
- Returns:
  the NDArray after applying the tabNetGLU function
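A gated linear unit splits its input in two along the feature axis and multiplies one half by the sigmoid of the other, so an input with 2 × units features yields units features. The sketch below shows that arithmetic in plain Java on a (batch, features) array. The class name GluSketch is hypothetical, and treating the first half as the linear part and the second half as the gate is an assumption about DJL's tabNetGLU, not something this page states.

```java
import java.util.Arrays;

/** Hypothetical pure-Java sketch of a GLU split over the feature axis. */
public final class GluSketch {

    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * x has shape (batch, 2 * units); the result has shape (batch, units).
     * Assumption: the first half is the linear part and the second half is the gate.
     */
    public static double[][] glu(double[][] x, int units) {
        double[][] out = new double[x.length][units];
        for (int b = 0; b < x.length; b++) {
            if (x[b].length != 2 * units) {
                throw new IllegalArgumentException("expected " + 2 * units + " features per row");
            }
            for (int j = 0; j < units; j++) {
                out[b][j] = x[b][j] * sigmoid(x[b][j + units]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{2.0, -4.0, 0.0, 0.0}}; // batch of 1 row, 2 * units = 4 features
        System.out.println(Arrays.deepToString(glu(x, 2))); // [[1.0, -2.0]]
    }
}
```

With a zero gate, sigmoid(0) = 0.5 halves each linear value, which makes the example easy to check by hand.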
-
tabNetGLU
public static ai.djl.ndarray.NDList tabNetGLU(ai.djl.ndarray.NDList arrays, int units)
Applies the tabNetGLU activation (used mainly in TabNet) to the input singleton NDList.
- Parameters:
  arrays - the input singleton NDList
  units - half the number of input features; the result has units features
- Returns:
  the singleton NDList after applying the tabNetGLU function
-
tabNetGLUBlock
public static ai.djl.nn.Block tabNetGLUBlock(int units)
Creates a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function in its forward function.
- Parameters:
  units - half the number of input features
- Returns:
  a LambdaBlock that applies the tabNetGLU(NDArray, int) activation function
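In DJL, a LambdaBlock wraps a function that has no parameters of its own. As a rough, DJL-free analogue, the sketch below represents such a block as a java.util.function.UnaryOperator over (batch, features) arrays. The class LambdaBlockSketch and its gluBlock method are hypothetical names, and the GLU split order (first half linear, second half gate) is an assumption.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

/** Hypothetical analogue of wrapping a stateless activation in a parameter-free block. */
public final class LambdaBlockSketch {

    /** Returns a parameter-free "block" whose forward pass applies GLU with the given units. */
    public static UnaryOperator<double[][]> gluBlock(int units) {
        return x -> {
            double[][] out = new double[x.length][units];
            for (int b = 0; b < x.length; b++) {
                for (int j = 0; j < units; j++) {
                    double gate = 1.0 / (1.0 + Math.exp(-x[b][j + units])); // sigmoid of the gate half
                    out[b][j] = x[b][j] * gate;
                }
            }
            return out;
        };
    }

    public static void main(String[] args) {
        UnaryOperator<double[][]> block = gluBlock(1);
        // The same block object can be reused on any input with 2 * units features per row.
        double[][] y = block.apply(new double[][] {{3.0, 0.0}, {-1.0, 0.0}});
        System.out.println(Arrays.deepToString(y)); // [[1.5], [-0.5]]
    }
}
```

Because the block holds no weights, the only configuration it captures is `units`, which mirrors why tabNetGLUBlock takes a single int argument.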
-
gluBlock
public static ai.djl.nn.Block gluBlock(ai.djl.nn.Block sharedBlock, int outDim, int virtualBatchSize, float batchNormMomentum)
Creates an FC-BN-GLU block used in TabNet. Because GLU halves the feature dimension, an FC layer first doubles the dimension of the features fed into the GLU.
- Parameters:
  sharedBlock - the shared fully connected layer
  outDim - the output feature dimension
  virtualBatchSize - the virtual batch size for ghost batch norm
  batchNormMomentum - the momentum used for the ghost batch norm layer
- Returns:
  an FC-BN-GLU block
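To make the FC-BN-GLU data flow concrete, here is a pure-Java sketch. A bias-free fully connected layer widens each row to 2 × outDim features, ghost batch normalization normalizes each chunk of virtualBatchSize rows with that chunk's own mean and variance, and GLU halves the width back to outDim. All names are hypothetical. Learnable scale and shift, the running statistics that batchNormMomentum would update, and the sharing of FC weights across blocks are all omitted, so this shows only the training-time arithmetic, not DJL's implementation.

```java
/** Hypothetical pure-Java sketch of an FC -> ghost BN -> GLU pipeline (training-time data flow only). */
public final class FcBnGluSketch {

    /** Bias-free fully connected layer: (batch, in) x (in, out) -> (batch, out). */
    public static double[][] fc(double[][] x, double[][] w) {
        int out = w[0].length;
        double[][] y = new double[x.length][out];
        for (int b = 0; b < x.length; b++) {
            for (int k = 0; k < w.length; k++) {
                for (int j = 0; j < out; j++) {
                    y[b][j] += x[b][k] * w[k][j];
                }
            }
        }
        return y;
    }

    /**
     * Ghost batch norm: each chunk of virtualBatchSize rows is normalized with its own mean and
     * variance. Learnable scale/shift and momentum-based running statistics are omitted.
     */
    public static double[][] ghostBatchNorm(double[][] x, int virtualBatchSize, double eps) {
        int features = x[0].length;
        double[][] y = new double[x.length][features];
        for (int start = 0; start < x.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, x.length);
            int n = end - start;
            for (int j = 0; j < features; j++) {
                double mean = 0.0;
                for (int b = start; b < end; b++) {
                    mean += x[b][j];
                }
                mean /= n;
                double var = 0.0;
                for (int b = start; b < end; b++) {
                    var += (x[b][j] - mean) * (x[b][j] - mean);
                }
                var /= n;
                for (int b = start; b < end; b++) {
                    y[b][j] = (x[b][j] - mean) / Math.sqrt(var + eps);
                }
            }
        }
        return y;
    }

    /** GLU over the feature axis; assumes the first half is linear and the second half is the gate. */
    public static double[][] glu(double[][] x, int units) {
        double[][] out = new double[x.length][units];
        for (int b = 0; b < x.length; b++) {
            for (int j = 0; j < units; j++) {
                out[b][j] = x[b][j] / (1.0 + Math.exp(-x[b][j + units])); // x * sigmoid(gate)
            }
        }
        return out;
    }

    /** The FC layer widens to 2 * outDim so that GLU can halve the width back to outDim. */
    public static double[][] fcBnGlu(double[][] x, double[][] fcWeight, int outDim, int virtualBatchSize) {
        if (fcWeight[0].length != 2 * outDim) {
            throw new IllegalArgumentException("FC layer must output 2 * outDim features");
        }
        return glu(ghostBatchNorm(fc(x, fcWeight), virtualBatchSize, 1e-5), outDim);
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2}, {3, 4}, {5, 6}, {7, 8}}; // batch of 4 rows, 2 input features
        double[][] w = {{1, 0, 0, 1}, {0, 1, 1, 0}};     // maps 2 -> 4 = 2 * outDim
        double[][] y = fcBnGlu(x, w, 2, 2);               // two ghost batches of 2 rows each
        System.out.println(y.length + " x " + y[0].length); // 4 x 2
    }
}
```

Normalizing small virtual batches rather than the whole batch is the point of ghost batch norm: each chunk of rows gets its own statistics, which the sketch makes visible by looping over `start` in steps of `virtualBatchSize`.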
-
featureTransformer
public static ai.djl.nn.Block featureTransformer(java.util.List<ai.djl.nn.Block> sharedBlocks, int outDim, int numIndependent, int virtualBatchSize, float batchNormMomentum)
Creates a featureTransformer Block. The feature transformer is where all the selected features are processed to generate the final output.
- Parameters:
  sharedBlocks - the shared blocks of the feature transformer
  outDim - the output dimension of the feature transformer
  numIndependent - the number of independent blocks of the feature transformer
  virtualBatchSize - the virtual batch size for ghost batch norm
  batchNormMomentum - the momentum for the batch norm layer
- Returns:
  a feature transformer
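The TabNet paper builds the feature transformer from shared FC-BN-GLU blocks followed by step-specific ones, adding each later block's output back to its input and scaling the sum by √0.5 so the variance through the network does not change dramatically. The sketch below wires arbitrary blocks together that way, assuming the first block has no residual connection because it may change the feature width. FeatureTransformerSketch and its list-based signature are hypothetical and differ from DJL's featureTransformer, which takes a count, numIndependent, rather than a list of independent blocks.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

/** Hypothetical sketch of the feature-transformer wiring described in the TabNet paper. */
public final class FeatureTransformerSketch {

    private static final double SCALE = Math.sqrt(0.5);

    /**
     * Applies the shared blocks, then the independent ones. The first block may change the
     * feature width, so it has no residual; each later block computes (h + block(h)) * sqrt(0.5).
     */
    public static UnaryOperator<double[][]> featureTransformer(
            List<UnaryOperator<double[][]>> sharedBlocks,
            List<UnaryOperator<double[][]>> independentBlocks) {
        List<UnaryOperator<double[][]>> blocks = new ArrayList<>(sharedBlocks);
        blocks.addAll(independentBlocks);
        if (blocks.isEmpty()) {
            throw new IllegalArgumentException("at least one block is required");
        }
        return x -> {
            double[][] h = blocks.get(0).apply(x);
            for (int i = 1; i < blocks.size(); i++) {
                double[][] r = blocks.get(i).apply(h);
                double[][] next = new double[h.length][];
                for (int b = 0; b < h.length; b++) {
                    next[b] = new double[h[b].length];
                    for (int j = 0; j < h[b].length; j++) {
                        next[b][j] = (h[b][j] + r[b][j]) * SCALE; // scaled residual connection
                    }
                }
                h = next;
            }
            return h;
        };
    }

    public static void main(String[] args) {
        // Identity stand-ins for FC-BN-GLU blocks, for illustration only.
        UnaryOperator<double[][]> identity = x -> x;
        UnaryOperator<double[][]> ft =
                featureTransformer(List.of(identity, identity), List.of(identity));
        double[][] y = ft.apply(new double[][] {{1.0}});
        System.out.println(y[0][0]); // close to 2.0: two residual steps, each multiplying by sqrt(2)
    }
}
```

With identity stand-in blocks, each residual step computes (h + h) · √0.5 = √2 · h, so three blocks (one plain, two residual) take 1.0 to approximately 2.0.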
-
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
- Specified by:
  forwardInternal in class ai.djl.nn.AbstractBaseBlock
-
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
-
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
- Overrides:
  initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock
-
builder
public static TabNet.Builder builder()
Creates a builder to build a TabNet.
- Returns:
  a new builder