Class TabNet.DecisionStep

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.basicmodelzoo.tabular.TabNet.DecisionStep
All Implemented Interfaces:
ai.djl.nn.Block
Enclosing class:
TabNet

public static final class TabNet.DecisionStep extends ai.djl.nn.AbstractBlock
DecisionStep combines a featureTransformer and an attentionTransformer into a single block.
  • Field Summary

    Fields inherited from class ai.djl.nn.AbstractBlock

    children, parameters

    Fields inherited from class ai.djl.nn.AbstractBaseBlock

    inputNames, inputShapes, outputDataTypes, version
  • Constructor Summary

    Constructors
    Constructor
    Description
    DecisionStep(int inputDim, int numD, int numA, List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
    Creates a TabNet.DecisionStep with the given parameters.
  • Method Summary

    Modifier and Type
    Method
    Description
    protected ai.djl.ndarray.NDList
    forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
    ai.djl.ndarray.types.Shape[]
    getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    protected void
    initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)

    Methods inherited from class ai.djl.nn.AbstractBlock

    addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters

    Methods inherited from class ai.djl.nn.AbstractBaseBlock

    beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface ai.djl.nn.Block

    forward, freezeParameters, freezeParameters, getOutputShapes
  • Constructor Details

    • DecisionStep

      public DecisionStep(int inputDim, int numD, int numA, List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
      Creates a TabNet.DecisionStep with the given parameters.
      Parameters:
      inputDim - the number of input dimensions for the attentionTransformer
      numD - the number of dimensions excluding the attentionTransformer
      numA - the number of dimensions for the attentionTransformer
      shared - the shared fullyConnected layers
      nInd - the number of independent fullyConnected layers
      virtualBatchSize - the virtual batch size
      batchNormMomentum - the momentum for the batchNorm layer
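
      The relationship between numD and numA can be sketched in plain Java. The sketch assumes the general TabNet design, in which a feature vector of width numD + numA is split into a decision part and an attention part. This page does not confirm that DecisionStep performs the split itself. The class DecisionStepDimsSketch and its split method are hypothetical names introduced for illustration and are not part of DJL.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of ai.djl.basicmodelzoo.
final class DecisionStepDimsSketch {

    private DecisionStepDimsSketch() {}

    // Splits one feature vector of length numD + numA into two parts:
    // index 0 holds the first numD values (assumed decision part),
    // index 1 holds the remaining numA values (assumed attention part).
    static float[][] split(float[] features, int numD, int numA) {
        if (numD < 0 || numA < 0) {
            throw new IllegalArgumentException("numD and numA must be non-negative");
        }
        if (features.length != numD + numA) {
            throw new IllegalArgumentException(
                    "expected " + (numD + numA) + " values but got " + features.length);
        }
        float[] decision = Arrays.copyOfRange(features, 0, numD);
        float[] attention = Arrays.copyOfRange(features, numD, numD + numA);
        return new float[][] {decision, attention};
    }

    public static void main(String[] args) {
        // Assumed example sizes: numD = 3, numA = 2.
        float[] features = {1f, 2f, 3f, 4f, 5f};
        float[][] parts = split(features, 3, 2);
        System.out.println("decision part:  " + Arrays.toString(parts[0]));
        System.out.println("attention part: " + Arrays.toString(parts[1]));
    }
}
```

      The sketch works on a single vector for readability; a real block would operate on batched NDArray inputs instead.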
  • Method Details

    • forwardInternal

      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class ai.djl.nn.AbstractBaseBlock
    • getOutputShapes

      public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    • initializeChildBlocks

      protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
      Overrides:
      initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock