Class TabNet.AttentionTransformer

  • All Implemented Interfaces:
    ai.djl.nn.Block
    Enclosing class:
    TabNet

    public static final class TabNet.AttentionTransformer
    extends ai.djl.nn.AbstractBlock
    AttentionTransformer is where the TabNet model learns the relationship between relevant features and decides which features to pass on to the feature transformer of the current decision step.
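The TabNet paper (Arik and Pfister) formulates the attentive transformer like this. For decision step i, it computes a sparse feature mask M[i] = sparsemax(P[i-1] * h_i(a[i-1])). Here h_i is a fully connected layer followed by batch normalization, and P is a prior scale that records how much each feature has already been used. After each step the prior is updated as P[i] = P[i-1] * (gamma - M[i]). The plain-Java sketch below illustrates only that paper formulation. It does not use DJL and is not a description of how this block's forwardInternal is implemented. The class and method names (AttentionMaskSketch, sparsemax, attentiveMask, updatePrior), the input values, and gamma = 1.5 are hypothetical. The sketch also assumes that the output of h_i has already been computed.

```java
import java.util.Arrays;

/**
 * Hypothetical, framework-free sketch of the TabNet attentive-transformer step:
 * mask = sparsemax(prior * h), then prior <- prior * (gamma - mask).
 * "h" stands for the already-computed output of the FC + batch-norm layer.
 * This is NOT DJL's implementation of TabNet.AttentionTransformer.
 */
public class AttentionMaskSketch {

    /** Sparsemax (Martins and Astudillo, 2016): Euclidean projection onto the probability simplex. */
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; walked from the end to get descending order
        int n = sorted.length;
        double cumSum = 0.0;
        double supportSum = 0.0;
        int k = 0;
        // Find the support size k(z): the largest j with 1 + j * z_(j) > sum of the top j values.
        for (int j = 1; j <= n; j++) {
            double zj = sorted[n - j];
            cumSum += zj;
            if (1.0 + j * zj > cumSum) {
                k = j;
                supportSum = cumSum;
            }
        }
        double tau = (supportSum - 1.0) / k; // threshold subtracted from every entry
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0); // entries below tau become exactly zero
        }
        return p;
    }

    /** Computes one step's mask per sample (row) from processed features h and prior scales. */
    static double[][] attentiveMask(double[][] h, double[][] prior) {
        double[][] mask = new double[h.length][];
        for (int b = 0; b < h.length; b++) {
            double[] scaled = new double[h[b].length];
            for (int d = 0; d < scaled.length; d++) {
                scaled[d] = prior[b][d] * h[b][d]; // features already used heavily are damped
            }
            mask[b] = sparsemax(scaled);
        }
        return mask;
    }

    /** Updates the prior in place so that features selected now are down-weighted in later steps. */
    static void updatePrior(double[][] prior, double[][] mask, double gamma) {
        for (int b = 0; b < prior.length; b++) {
            for (int d = 0; d < prior[b].length; d++) {
                prior[b][d] *= gamma - mask[b][d];
            }
        }
    }

    public static void main(String[] args) {
        double[][] h = {{2.0, 1.0, -1.0}};   // hypothetical processed features, one sample
        double[][] prior = {{1.0, 1.0, 1.0}}; // every feature initially fully available
        double[][] mask = attentiveMask(h, prior);
        System.out.println(Arrays.toString(mask[0]));  // prints [1.0, 0.0, 0.0]
        updatePrior(prior, mask, 1.5);                 // gamma = 1.5 is an arbitrary example value
        System.out.println(Arrays.toString(prior[0])); // prints [0.5, 1.5, 1.5]
    }
}
```

In this example, sparsemax assigns the whole mask to the first feature. Its prior scale then drops to gamma - 1 = 0.5, so a later step is less likely to select it again. Unlike softmax, sparsemax can return exact zeros, which is why the resulting feature selection is sparse.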
    • Field Summary

      • Fields inherited from class ai.djl.nn.AbstractBlock

        children, parameters
      • Fields inherited from class ai.djl.nn.AbstractBaseBlock

        inputNames, inputShapes, outputDataTypes, version
    • Method Summary

      Modifier and Type Method Description
      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
      ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager, ai.djl.ndarray.types.DataType dataType, ai.djl.ndarray.types.Shape... inputShapes)
      • Methods inherited from class ai.djl.nn.AbstractBlock

        addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
      • Methods inherited from class ai.djl.nn.AbstractBaseBlock

        beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
      • Methods inherited from interface ai.djl.nn.Block

        forward, freezeParameters, getOutputShapes
    • Method Detail

      • forwardInternal

        protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore,
                                                        ai.djl.ndarray.NDList inputs,
                                                        boolean training,
                                                        ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
        Specified by:
        forwardInternal in class ai.djl.nn.AbstractBaseBlock
      • getOutputShapes

        public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      • initializeChildBlocks

        protected void initializeChildBlocks(ai.djl.ndarray.NDManager manager,
                                             ai.djl.ndarray.types.DataType dataType,
                                             ai.djl.ndarray.types.Shape... inputShapes)
        Overrides:
        initializeChildBlocks in class ai.djl.nn.AbstractBaseBlock