Package ai.djl.basicmodelzoo.tabular
Class TabNet.DecisionStep
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.basicmodelzoo.tabular.TabNet.DecisionStep
All Implemented Interfaces:
ai.djl.nn.Block
Enclosing class:
TabNet
public static final class TabNet.DecisionStep
extends ai.djl.nn.AbstractBlock
DecisionStep combines a featureTransformer and an attentionTransformer into a single TabNet decision step.
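The following is a minimal, framework-free sketch of what a TabNet decision step computes for one sample. It follows the original TabNet design rather than DJL's source. The dense matrices `wAtt` and `wFeat`, the class `DecisionStepSketch`, and its helpers `sparsemax`, `matVec`, and `step` are all hypothetical stand-ins. The real attentionTransformer and featureTransformer are DJL blocks built from fully connected, batch-norm, and GLU layers, and DecisionStep's actual inputs and outputs may differ. The sketch shows the three stages: an attention mask computed with sparsemax, masking of the input features, and a feature transform whose output is split into a decision part of width numD and an attention part of width numA.

```java
import java.util.Arrays;

// Conceptual sketch of one TabNet decision step on a single sample (not DJL code).
// Hypothetical stand-ins: wAtt replaces the attentionTransformer and wFeat replaces the
// featureTransformer; the real blocks also use batch normalization and GLU layers.
final class DecisionStepSketch {

    // Sparsemax: Euclidean projection of z onto the probability simplex.
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; walked below from the largest value down
        int n = sorted.length;
        double cumSum = 0.0;
        double tau = 0.0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k];
            cumSum += zk;
            if (1.0 + k * zk > cumSum) {
                tau = (cumSum - 1.0) / k; // threshold for the largest valid support size
            }
        }
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    // Dense matrix-vector product: w has shape rows x v.length.
    static double[] matVec(double[][] w, double[] v) {
        double[] out = new double[w.length];
        for (int r = 0; r < w.length; r++) {
            for (int c = 0; c < v.length; c++) {
                out[r] += w[r][c] * v[c];
            }
        }
        return out;
    }

    // Returns {d, aNext, mask}: decision part (numD), attention part (numA), feature mask.
    static double[][] step(double[] x, double[] a, double[] prior,
                           double[][] wAtt, double[][] wFeat, int numD) {
        // 1. Attention: mask = sparsemax(prior * h(a)), one weight per input feature.
        double[] h = matVec(wAtt, a);
        double[] logits = new double[h.length];
        for (int i = 0; i < h.length; i++) {
            logits[i] = prior[i] * h[i];
        }
        double[] mask = sparsemax(logits);

        // 2. Feature selection: scale each input feature by its mask weight.
        double[] masked = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            masked[i] = mask[i] * x[i];
        }

        // 3. Feature transformer, then split its numD + numA outputs.
        double[] out = matVec(wFeat, masked);
        double[] d = Arrays.copyOfRange(out, 0, numD);
        double[] aNext = Arrays.copyOfRange(out, numD, out.length);
        return new double[][] {d, aNext, mask};
    }

    public static void main(String[] args) {
        double[] x = {1.0, 2.0, 3.0};                              // inputDim = 3
        double[] a = {1.0, 0.0};                                   // numA = 2
        double[] prior = {1.0, 1.0, 1.0};                          // no feature used yet
        double[][] wAtt = {{1.0, 0.0}, {0.5, 0.0}, {0.0, 0.0}};    // inputDim x numA
        double[][] wFeat = {{1.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 1.0}}; // (numD + numA) x inputDim
        double[][] r = step(x, a, prior, wAtt, wFeat, 1);          // numD = 1
        System.out.println("mask  = " + Arrays.toString(r[2]));    // [0.75, 0.25, 0.0]
        System.out.println("d     = " + Arrays.toString(r[0]));    // [0.75]
        System.out.println("aNext = " + Arrays.toString(r[1]));    // [0.5, 0.0]
    }
}
```

Sparsemax gives the third feature exactly zero weight, which is what allows TabNet to select features sparsely. In the TabNet paper, the model then reduces the prior for features already used before running the next step; that bookkeeping is omitted here.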
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
Constructor    Description
DecisionStep(int inputDim, int numD, int numA, List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
    Creates a TabNet.DecisionStep with the given parameters.
Method Summary
Modifier and Type    Method
protected ai.djl.ndarray.NDList    forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
ai.djl.ndarray.types.Shape[]    getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
protected void    initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
Constructor Details
DecisionStep
public DecisionStep(int inputDim, int numD, int numA, List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
Creates a TabNet.DecisionStep with the given parameters.
Parameters:
inputDim - the number of input dimensions for the attentionTransformer
numD - the number of dimensions excluding those for the attentionTransformer
numA - the number of dimensions for the attentionTransformer
shared - the shared fullyConnected layers
nInd - the number of independent fullyConnected layers
virtualBatchSize - the virtual batch size
batchNormMomentum - the momentum for the batchNorm layer
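The virtualBatchSize and batchNormMomentum parameters point to ghost batch normalization, which the TabNet paper uses. Assuming that is what they control here, the plain-Java sketch below shows the idea for a single feature. Each training batch is split into chunks of virtualBatchSize, and each chunk is normalized with its own mean and variance. The running statistics are updated once per chunk using the momentum. The class `GhostBatchNormSketch`, the momentum convention (`running = momentum * running + (1 - momentum) * chunkStatistic`), and the omission of learnable scale and shift are assumptions, not DJL's BatchNorm implementation.

```java
import java.util.Arrays;

// Conceptual sketch of ghost ("virtual") batch normalization for a single feature.
// Assumptions: virtualBatchSize splits each training batch into independently normalized
// chunks, and running = momentum * running + (1 - momentum) * chunkStatistic.
// Learnable scale and shift (gamma, beta) are omitted for brevity.
final class GhostBatchNormSketch {
    private static final double EPS = 1e-5;
    private final int virtualBatchSize;
    private final double momentum;
    double runningMean = 0.0;
    double runningVar = 1.0;

    GhostBatchNormSketch(int virtualBatchSize, double momentum) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        this.virtualBatchSize = virtualBatchSize;
        this.momentum = momentum;
    }

    // Training-mode forward pass over one batch of values for a single feature.
    double[] forwardTraining(double[] batch) {
        double[] out = new double[batch.length];
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch.length); // last chunk may be smaller
            int n = end - start;
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += batch[i];
            }
            mean /= n;
            double var = 0.0;
            for (int i = start; i < end; i++) {
                var += (batch[i] - mean) * (batch[i] - mean);
            }
            var /= n; // biased (population) variance of the chunk
            double invStd = 1.0 / Math.sqrt(var + EPS);
            for (int i = start; i < end; i++) {
                out[i] = (batch[i] - mean) * invStd;
            }
            // Running statistics are updated once per virtual batch, not once per full batch.
            runningMean = momentum * runningMean + (1.0 - momentum) * mean;
            runningVar = momentum * runningVar + (1.0 - momentum) * var;
        }
        return out;
    }

    public static void main(String[] args) {
        GhostBatchNormSketch bn = new GhostBatchNormSketch(2, 0.9); // virtualBatchSize = 2
        double[] y = bn.forwardTraining(new double[] {1.0, 3.0, 10.0, 20.0});
        System.out.println(Arrays.toString(y)); // each pair normalized separately, roughly [-1, 1, -1, 1]
        System.out.println("runningMean = " + bn.runningMean + ", runningVar = " + bn.runningVar);
    }
}
```

With virtualBatchSize = 2, each pair is normalized to roughly -1 and +1 on its own, even though the two pairs have very different scales. A single normalization over the full batch of four values would not produce this result.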
Method Details
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
Specified by:
forwardInternal in class ai.djl.nn.AbstractBaseBlock
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
Overrides:
initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock