Class TabNet.AttentionTransformer

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.basicmodelzoo.tabular.TabNet.AttentionTransformer
All Implemented Interfaces:
ai.djl.nn.Block
Enclosing class:
TabNet

public static final class TabNet.AttentionTransformer extends ai.djl.nn.AbstractBlock
AttentionTransformer is where the TabNet model learns the relationships between relevant features and decides which features to pass on to the feature transformer of the current decision step.
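
To make the feature-selection idea concrete, the following is a minimal sketch in plain Java, not DJL's implementation. It assumes the block behaves like the attentive transformer in the TabNet paper: scores are scaled by a prior, turned into a sparse mask with sparsemax, and the mask is multiplied into the features. The class name AttentionMaskSketch, its methods, and the sample values are hypothetical, and the learned layers that would produce the scores are omitted.

```java
import java.util.Arrays;

// Hypothetical illustration only; none of these names exist in DJL.
final class AttentionMaskSketch {

    // Sparsemax: projects the scores onto the probability simplex.
    // Unlike softmax, it can assign exactly zero weight to a feature.
    static double[] sparsemax(double[] scores) {
        double[] sorted = scores.clone();
        Arrays.sort(sorted); // ascending; read from the end for descending order
        double cumulative = 0.0;
        double tau = 0.0;
        for (int k = 1; k <= sorted.length; k++) {
            double value = sorted[sorted.length - k];
            cumulative += value;
            // The k-th largest score stays in the support while this holds.
            if (1.0 + k * value > cumulative) {
                tau = (cumulative - 1.0) / k;
            }
        }
        double[] mask = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            mask[i] = Math.max(scores[i] - tau, 0.0);
        }
        return mask;
    }

    // Assumed step: scale the scores by the prior, then apply sparsemax.
    static double[] attentionMask(double[] scores, double[] prior) {
        double[] scaled = new double[scores.length];
        for (int i = 0; i < scores.length; i++) {
            scaled[i] = scores[i] * prior[i];
        }
        return sparsemax(scaled);
    }

    // Element-wise product: features with zero mask weight are dropped.
    static double[] applyMask(double[] features, double[] mask) {
        double[] selected = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            selected[i] = features[i] * mask[i];
        }
        return selected;
    }

    public static void main(String[] args) {
        double[] scores = {1.0, 0.8, 0.1};   // made-up attention scores
        double[] prior = {1.0, 1.0, 1.0};    // no feature used by earlier steps
        double[] features = {5.0, 3.0, 7.0}; // made-up input features
        double[] mask = attentionMask(scores, prior);
        System.out.println("mask     = " + Arrays.toString(mask));
        System.out.println("selected = " + Arrays.toString(applyMask(features, mask)));
    }
}
```

In this sketch the mask entries are non-negative and sum to 1, and the lowest-scoring feature receives zero weight, so it contributes nothing to the masked features.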
  • Field Summary

    Fields inherited from class ai.djl.nn.AbstractBlock

    children, parameters

    Fields inherited from class ai.djl.nn.AbstractBaseBlock

    inputNames, inputShapes, outputDataTypes, version
  • Method Summary

    Modifier and Type
    Method
    Description
    protected ai.djl.ndarray.NDList
    forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
    ai.djl.ndarray.types.Shape[]
    getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    protected void
    initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)

    Methods inherited from class ai.djl.nn.AbstractBlock

    addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters

    Methods inherited from class ai.djl.nn.AbstractBaseBlock

    beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface ai.djl.nn.Block

    forward, freezeParameters, freezeParameters, getOutputShapes
  • Method Details

    • forwardInternal

      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class ai.djl.nn.AbstractBaseBlock
    • getOutputShapes

      public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    • initializeChildBlocks

      protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
      Overrides:
      initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock