Class TabNet.DecisionStep

  • All Implemented Interfaces:
    ai.djl.nn.Block
    Enclosing class:
    TabNet

    public static final class TabNet.DecisionStep
    extends ai.djl.nn.AbstractBlock
    DecisionStep combines a featureTransformer and an attentionTransformer into a single block.
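    The following minimal, framework-free Java sketch shows the data flow this combination implies for a single row of features. It is not DJL code. The class DecisionStepSketch and its method step are hypothetical, and the two transformers are stand-in UnaryOperator<double[]> functions. Prior scales and batching are omitted. Splitting the feature transformer output into a first numD values (decision part) and a last numA values (attention part) follows the TabNet paper's description and is an assumption about the ordering, not confirmed DJL internals.

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

// Framework-free sketch of one TabNet decision step for a single row of features.
// The transformers are hypothetical stand-ins passed as plain functions, not DJL blocks.
public class DecisionStepSketch {

    /**
     * Runs one step and returns {decisionPart, attentionPart}.
     *
     * @param features             the input features of one row
     * @param previousAttention    attention part produced by the previous step
     * @param attentionTransformer maps previousAttention to a mask with one entry per feature
     * @param featureTransformer   maps the masked features to numD + numA values
     */
    static double[][] step(double[] features, double[] previousAttention,
                           UnaryOperator<double[]> attentionTransformer,
                           UnaryOperator<double[]> featureTransformer,
                           int numD, int numA) {
        // 1. The attention part decides how strongly each input feature is used.
        double[] mask = attentionTransformer.apply(previousAttention);
        if (mask.length != features.length) {
            throw new IllegalArgumentException("mask must have one entry per feature");
        }
        // 2. Apply the mask element-wise to the input features.
        double[] masked = new double[features.length];
        for (int i = 0; i < features.length; i++) {
            masked[i] = mask[i] * features[i];
        }
        // 3. The feature part transforms the masked features.
        double[] out = featureTransformer.apply(masked);
        if (out.length != numD + numA) {
            throw new IllegalArgumentException("feature output must have numD + numA values");
        }
        // 4. Split (assumed ordering): first numD values for the decision,
        //    last numA values as input to the next step's attention part.
        double[] decision = Arrays.copyOfRange(out, 0, numD);
        double[] attention = Arrays.copyOfRange(out, numD, numD + numA);
        return new double[][] {decision, attention};
    }

    public static void main(String[] args) {
        double[] features = {1.0, 2.0, 3.0};
        double[] previousAttention = {0.5, 0.5};
        // Toy attention part: keep features 0 and 2, drop feature 1.
        UnaryOperator<double[]> toyAttention = a -> new double[] {1.0, 0.0, 1.0};
        // Toy feature part: return the masked features followed by their sum (4 values).
        UnaryOperator<double[]> toyFeature = x -> {
            double[] y = Arrays.copyOf(x, x.length + 1);
            double sum = 0.0;
            for (double v : x) {
                sum += v;
            }
            y[x.length] = sum;
            return y;
        };
        double[][] result = step(features, previousAttention, toyAttention, toyFeature, 2, 2);
        System.out.println(Arrays.toString(result[0]) + " " + Arrays.toString(result[1]));
    }
}
```

With the toy components in main, feature 1 is masked out, so the program prints [1.0, 0.0] [3.0, 4.0].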
    • Field Summary

      • Fields inherited from class ai.djl.nn.AbstractBlock

        children, parameters
      • Fields inherited from class ai.djl.nn.AbstractBaseBlock

        inputNames, inputShapes, outputDataTypes, version
    • Constructor Summary

      Constructors 
      Constructor Description
      DecisionStep(int inputDim, int numD, int numA, java.util.List<ai.djl.nn.Block> shared, int nInd, int virtualBatchSize, float batchNormMomentum)
      Creates a TabNet.DecisionStep with the given parameters.
    • Method Summary

      Modifier and Type Method Description
      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
      ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
      • Methods inherited from class ai.djl.nn.AbstractBlock

        addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
      • Methods inherited from class ai.djl.nn.AbstractBaseBlock

        beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
      • Methods inherited from interface ai.djl.nn.Block

        forward, freezeParameters, freezeParameters, getOutputShapes
    • Constructor Detail

      • DecisionStep

        public DecisionStep(int inputDim,
                            int numD,
                            int numA,
                            java.util.List<ai.djl.nn.Block> shared,
                            int nInd,
                            int virtualBatchSize,
                            float batchNormMomentum)
        Creates a TabNet.DecisionStep with the given parameters.
        Parameters:
        inputDim - the number of input dimensions for the attentionTransformer
        numD - the number of dimensions, excluding those used by the attentionTransformer
        numA - the number of dimensions for the attentionTransformer
        shared - the shared fully connected layers
        nInd - the number of independent fully connected layers
        virtualBatchSize - the virtual batch size
        batchNormMomentum - the momentum for the batch normalization layers
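        The following sketch illustrates what virtualBatchSize and batchNormMomentum typically control. It assumes, as in the TabNet paper, that virtualBatchSize drives ghost batch normalization: each training batch is split into chunks of at most virtualBatchSize rows, and each chunk is normalized with its own statistics. It also assumes the MXNet-style momentum convention running = momentum * running + (1 - momentum) * chunkStatistic, and some other frameworks define momentum the opposite way. GhostBatchNormSketch is a hypothetical class. It works on a single feature column, omits the learned scale and shift, and does not reproduce DJL's implementation.

```java
import java.util.Arrays;

// Sketch of ghost (virtual-batch) batch normalization on one feature column.
// Assumptions: chunks of virtualBatchSize rows are normalized separately, and the
// running statistics use running = momentum * running + (1 - momentum) * chunkStat.
public class GhostBatchNormSketch {

    static final double EPS = 1e-5; // numerical stabilizer added to the variance

    final int virtualBatchSize;
    final double momentum;
    double runningMean = 0.0; // assumed initial values, used at inference time
    double runningVar = 1.0;

    GhostBatchNormSketch(int virtualBatchSize, double momentum) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        this.virtualBatchSize = virtualBatchSize;
        this.momentum = momentum;
    }

    /** Number of chunks a batch is split into (the last chunk may be smaller). */
    static int numVirtualBatches(int batchSize, int virtualBatchSize) {
        return (batchSize + virtualBatchSize - 1) / virtualBatchSize;
    }

    /** Normalizes a training batch chunk by chunk and updates the running statistics. */
    double[] forwardTraining(double[] batch) {
        double[] out = new double[batch.length];
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch.length);
            int n = end - start;
            // Statistics of this chunk only, not of the whole batch.
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += batch[i];
            }
            mean /= n;
            double var = 0.0;
            for (int i = start; i < end; i++) {
                var += (batch[i] - mean) * (batch[i] - mean);
            }
            var /= n;
            for (int i = start; i < end; i++) {
                out[i] = (batch[i] - mean) / Math.sqrt(var + EPS);
            }
            // Exponential moving average controlled by the momentum.
            runningMean = momentum * runningMean + (1 - momentum) * mean;
            runningVar = momentum * runningVar + (1 - momentum) * var;
        }
        return out;
    }

    public static void main(String[] args) {
        GhostBatchNormSketch bn = new GhostBatchNormSketch(2, 0.9);
        double[] out = bn.forwardTraining(new double[] {1, 3, 10, 20});
        System.out.println(Arrays.toString(out));
        System.out.println("running mean = " + bn.runningMean + ", running var = " + bn.runningVar);
    }
}
```

For the batch {1, 3, 10, 20} with virtualBatchSize 2, each two-value chunk normalizes to roughly -1 and 1 even though the chunks have very different scales. After the two updates, the running mean has moved from 0 to about 1.68.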
    • Method Detail

      • forwardInternal

        protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore,
                                                        ai.djl.ndarray.NDList inputs,
                                                        boolean training,
                                                        ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        Specified by:
        forwardInternal in class ai.djl.nn.AbstractBaseBlock
      • getOutputShapes

        public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      • initializeChildBlocks

        protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager,
                                             ai.djl.ndarray.types.DataType dataType,
                                             ai.djl.ndarray.types.Shape... inputShapes)
        Overrides:
        initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock