Package ai.djl.basicmodelzoo.tabular
Class TabNet.DecisionStep
java.lang.Object
    ai.djl.nn.AbstractBaseBlock
        ai.djl.nn.AbstractBlock
            ai.djl.basicmodelzoo.tabular.TabNet.DecisionStep
All Implemented Interfaces:
ai.djl.nn.Block
Enclosing class:
TabNet
public static final class TabNet.DecisionStep extends ai.djl.nn.AbstractBlock
DecisionStep combines a featureTransformer and an attentionTransformer into a single block.
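To make the combination concrete, the following is a toy, single-sample sketch of one decision step as described in the TabNet paper. The attentionTransformer turns the prior scales and attention logits into a sparse feature mask with sparsemax, the masked features pass through the featureTransformer, and its numD + numA outputs are split into a ReLU-activated decision part and an attention part for the next step. This is not DJL's implementation: it uses plain double arrays, the learned layers are replaced by precomputed `attentionLogits` and a hypothetical `toyFeatureTransformer`, and the class name `DecisionStepSketch` is illustrative.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

/** Toy single-sample sketch of one TabNet-style decision step (illustrative, not DJL's code). */
final class DecisionStepSketch {

    /** Outputs of one step: the feature mask, the ReLU'd decision part d, and the attention part a. */
    static final class StepOutput {
        final double[] mask;
        final double[] decision;
        final double[] attention;

        StepOutput(double[] mask, double[] decision, double[] attention) {
            this.mask = mask;
            this.decision = decision;
            this.attention = attention;
        }
    }

    /** Sparsemax: Euclidean projection of z onto the probability simplex (many entries become exactly 0). */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; scanned below from the largest value down
        int n = sorted.length;
        double cumSum = 0.0;
        double supportSum = 0.0;
        int support = 0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k];
            cumSum += zk;
            if (1.0 + k * zk > cumSum) { // zk still belongs to the support
                support = k;
                supportSum = cumSum;
            }
        }
        double tau = (supportSum - 1.0) / support; // threshold subtracted from every entry
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    /**
     * One step: mask = sparsemax(prior * attentionLogits); the featureTransformer sees mask * features,
     * and its numD + numA outputs are split into d (ReLU) and a. attentionLogits is a hypothetical
     * stand-in for the output of the attentionTransformer's learned layers.
     */
    static StepOutput step(double[] features, double[] prior, double[] attentionLogits,
                           UnaryOperator<double[]> featureTransformer, int numD, int numA) {
        int n = features.length;
        double[] scaled = new double[n];
        for (int i = 0; i < n; i++) {
            scaled[i] = prior[i] * attentionLogits[i];
        }
        double[] mask = sparsemax(scaled);
        double[] masked = new double[n];
        for (int i = 0; i < n; i++) {
            masked[i] = mask[i] * features[i];
        }
        double[] out = featureTransformer.apply(masked);
        if (out.length != numD + numA) {
            throw new IllegalArgumentException("featureTransformer must return numD + numA values");
        }
        double[] decision = new double[numD];
        for (int i = 0; i < numD; i++) {
            decision[i] = Math.max(out[i], 0.0); // ReLU on the decision part
        }
        double[] attention = Arrays.copyOfRange(out, numD, numD + numA); // passed to the next step
        return new StepOutput(mask, decision, attention);
    }

    /** Hypothetical toy featureTransformer for numD = 2, numA = 1: returns {s, -s, s / 2}, s = sum(x). */
    static double[] toyFeatureTransformer(double[] x) {
        double s = 0.0;
        for (double v : x) {
            s += v;
        }
        return new double[] {s, -s, s / 2};
    }

    public static void main(String[] args) {
        StepOutput out = step(new double[] {1.0, 2.0, 3.0}, new double[] {1.0, 1.0, 1.0},
                new double[] {0.1, 0.2, 2.0}, DecisionStepSketch::toyFeatureTransformer, 2, 1);
        System.out.println(Arrays.toString(out.mask));      // [0.0, 0.0, 1.0]
        System.out.println(Arrays.toString(out.decision));  // [3.0, 0.0]
        System.out.println(Arrays.toString(out.attention)); // [1.5]
    }
}
```

Sparsemax is used here because, unlike softmax, it can assign exactly zero weight to features, which is what makes the per-step feature selection sparse.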
Constructor Summary
DecisionStep(int inputDim, int numD, int numA, java.util.List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
    Creates a TabNet.DecisionStep with the given parameters.
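In the TabNet paper, each feature transformer mixes layers that are shared across all decision steps with layers that belong to a single step. One plausible reading of the shared and nInd constructor parameters is sketched below: the same list of shared layer objects is handed to every step, so their weights are reused, while each step builds nInd fresh layers of its own. `ToyLayer`, `ToyStep` and `SharedLayersDemo` are hypothetical stand-ins, not DJL classes.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a fully connected layer; the id only makes object identity visible. */
final class ToyLayer {
    private static int nextId = 0;
    final int id = nextId++;
}

/** Hypothetical stand-in for one step's feature transformer: shared layers first, then nInd own layers. */
final class ToyStep {
    final List<ToyLayer> layers = new ArrayList<>();

    ToyStep(List<ToyLayer> shared, int nInd) {
        layers.addAll(shared); // the same objects in every step, hence the same weights
        for (int i = 0; i < nInd; i++) {
            layers.add(new ToyLayer()); // new objects, hence step-specific weights
        }
    }
}

final class SharedLayersDemo {
    public static void main(String[] args) {
        List<ToyLayer> shared = List.of(new ToyLayer(), new ToyLayer()); // created once for all steps
        ToyStep step1 = new ToyStep(shared, 2);
        ToyStep step2 = new ToyStep(shared, 2);
        System.out.println(step1.layers.size());                        // 4
        System.out.println(step1.layers.get(0) == step2.layers.get(0)); // true: shared layer
        System.out.println(step1.layers.get(2) == step2.layers.get(2)); // false: independent layer
    }
}
```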
Method Summary
Modifier and Type | Method
protected ai.djl.ndarray.NDList
    forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
ai.djl.ndarray.types.Shape[]
    getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
protected void
    initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Constructor Detail
DecisionStep
public DecisionStep(int inputDim, int numD, int numA, java.util.List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
Creates a TabNet.DecisionStep with the given parameters.
Parameters:
inputDim - the number of input dimensions for the attentionTransformer
numD - the number of dimensions excluding those for the attentionTransformer
numA - the number of dimensions for the attentionTransformer
shared - the shared fullyConnected layers
nInd - the number of independent fullyConnected layers
virtualBatchSize - the virtual batch size
batchNormMomentum - the momentum for the batchNorm layer
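The virtualBatchSize parameter suggests ghost batch normalization, which the TabNet paper uses: a large batch is split into virtual batches of at most virtualBatchSize rows, each normalized with its own statistics, while running statistics are updated with a momentum term. The single-feature sketch below assumes that batchNormMomentum weights the old running statistic (running = momentum * running + (1 - momentum) * batchStat, the MXNet-style convention); the learnable scale and shift are omitted, and `GhostBatchNormSketch` is a hypothetical name, not a DJL class.

```java
import java.util.Arrays;

/** Single-feature sketch of ghost batch normalization (assumed semantics; scale/shift omitted). */
final class GhostBatchNormSketch {
    private static final double EPS = 1e-5;
    private final int virtualBatchSize;
    private final double momentum;
    double runningMean = 0.0;
    double runningVar = 1.0;

    GhostBatchNormSketch(int virtualBatchSize, double momentum) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        this.virtualBatchSize = virtualBatchSize;
        this.momentum = momentum;
    }

    /** Training-mode forward pass: each chunk of at most virtualBatchSize values uses its own statistics. */
    double[] forwardTraining(double[] batch) {
        double[] out = new double[batch.length];
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch.length);
            int n = end - start;
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += batch[i];
            }
            mean /= n;
            double var = 0.0;
            for (int i = start; i < end; i++) {
                var += (batch[i] - mean) * (batch[i] - mean);
            }
            var /= n; // biased variance of this virtual batch
            for (int i = start; i < end; i++) {
                out[i] = (batch[i] - mean) / Math.sqrt(var + EPS);
            }
            // Assumed convention: momentum weights the previous running statistic.
            runningMean = momentum * runningMean + (1 - momentum) * mean;
            runningVar = momentum * runningVar + (1 - momentum) * var;
        }
        return out;
    }

    public static void main(String[] args) {
        GhostBatchNormSketch gbn = new GhostBatchNormSketch(2, 0.9);
        // Virtual batches: {1, 3}, {10, 20}, {5}; each is centered on its own mean.
        double[] out = gbn.forwardTraining(new double[] {1, 3, 10, 20, 5});
        System.out.println(Arrays.toString(out));
        System.out.println(gbn.runningMean); // approximately 2.012
    }
}
```

Normalizing small virtual batches adds noise that can act as a regularizer, while the running statistics still summarize the whole stream for inference.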
Method Detail
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Overrides:
initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock