Package ai.djl.basicmodelzoo.tabular
Class TabNet.AttentionTransformer
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.basicmodelzoo.tabular.TabNet.AttentionTransformer
- All Implemented Interfaces:
ai.djl.nn.Block
- Enclosing class:
- TabNet
public static final class TabNet.AttentionTransformer
extends ai.djl.nn.AbstractBlock
AttentionTransformer is where the tabNet models learn the relationship between relevant
features, and decides which features to pass on to the feature transformer of the current
decision step.
-
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Method Summary
Modifier and TypeMethodDescriptionprotected ai.djl.ndarray.NDList
forwardInternal
(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params) ai.djl.ndarray.types.Shape[]
getOutputShapes
(ai.djl.ndarray.types.Shape[] inputShapes) protected void
initializeChildBlocks
(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes) Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
-
Method Details
-
forwardInternal
protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params) - Specified by:
forwardInternal
in classai.djl.nn.AbstractBaseBlock
-
getOutputShapes
public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes) -
initializeChildBlocks
protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes) - Overrides:
initializeChildBlocks
in classai.djl.nn.AbstractBaseBlock
-