Class TabNet.Builder

  • Enclosing class:
    TabNet

    public static class TabNet.Builder
    extends java.lang.Object
    The Builder to construct a TabNet object.
    • Constructor Detail

      • Builder

        public Builder()
    • Method Detail

      • setInputDim

        public TabNet.Builder setInputDim(int inputDim)
        Sets the input dimension of TabNet.
        Parameters:
        inputDim - the input dimension
        Returns:
        this Builder
      • setOutDim

        public TabNet.Builder setOutDim(int outDim)
        Sets the output dimension for TabNet.
        Parameters:
        outDim - the output dimension
        Returns:
        this Builder
      • optNumD

        public TabNet.Builder optNumD(int numD)
        Sets the number of dimensions used by all layers except the attentionTransformer.
        Parameters:
        numD - the number of dimensions used by all layers except the attentionTransformer
        Returns:
        this Builder
      • optNumA

        public TabNet.Builder optNumA(int numA)
        Sets the number of dimensions for the attentionTransformer.
        Parameters:
        numA - the number of dimensions for the attentionTransformer
        Returns:
        this Builder
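In the TabNet paper (Arik and Pfister), each decision step's feature transformer produces a vector that is split into a decision part of width n_d and an attention part of width n_a, and only the attention part feeds the next attentive transformer. Assuming numD and numA play those two roles here, which the descriptions above suggest but do not state outright, the split can be sketched on a plain double[] as follows. StepOutputSplit and split are hypothetical names, not DJL API.

```java
import java.util.Arrays;

/** Hypothetical helper: splits one step's feature-transformer output into decision and attention parts. */
final class StepOutputSplit {
    private StepOutputSplit() {}

    /** Returns {decision part of length numD, attention part of length numA}. */
    static double[][] split(double[] features, int numD, int numA) {
        if (features.length != numD + numA) {
            throw new IllegalArgumentException(
                    "expected " + (numD + numA) + " features, got " + features.length);
        }
        double[] decision = Arrays.copyOfRange(features, 0, numD);          // first numD entries
        double[] attention = Arrays.copyOfRange(features, numD, numD + numA); // remaining numA entries
        return new double[][] {decision, attention};
    }
}
```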
      • optNumShared

        public TabNet.Builder optNumShared(int numShared)
        Sets the number of shared fully connected layers.
        Parameters:
        numShared - the number of shared fully connected layers
        Returns:
        this Builder
      • optNumIndependent

        public TabNet.Builder optNumIndependent(int numIndependent)
        Sets the number of independent fully connected layers.
        Parameters:
        numIndependent - the number of independent fully connected layers
        Returns:
        this Builder
      • optNumSteps

        public TabNet.Builder optNumSteps(int numSteps)
        Sets the number of decision steps for TabNet.
        Parameters:
        numSteps - the number of decision steps for TabNet
        Returns:
        this Builder
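The TabNet paper aggregates the decision steps by summing the ReLU of each step's decision output and then mapping that sum to the output dimension with a final linear layer, so numSteps controls how many terms enter the sum. A minimal sketch of the aggregation, assuming DJL follows the paper here; DecisionAggregation is a hypothetical helper, and the final linear layer is omitted.

```java
/** Hypothetical sketch of TabNet-style decision aggregation across steps. */
final class DecisionAggregation {
    private DecisionAggregation() {}

    /** Sums ReLU(d[i]) over all decision steps; stepOutputs[i] is step i's decision vector. */
    static double[] aggregate(double[][] stepOutputs) {
        if (stepOutputs.length == 0) {
            throw new IllegalArgumentException("need at least one decision step");
        }
        int width = stepOutputs[0].length;
        double[] out = new double[width];
        for (double[] d : stepOutputs) {
            if (d.length != width) {
                throw new IllegalArgumentException("all steps must have the same width");
            }
            for (int j = 0; j < width; j++) {
                out[j] += Math.max(0.0, d[j]); // ReLU, then accumulate
            }
        }
        return out;
    }
}
```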
      • optVirtualBatchSize

        public TabNet.Builder optVirtualBatchSize(int virtualBatchSize)
        Sets the virtual batch size for ghost batch norm.
        Parameters:
        virtualBatchSize - the virtual batch size
        Returns:
        this Builder
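Ghost batch normalization (Hoffer et al., 2017), which the TabNet paper uses, splits each training batch into consecutive virtual batches of virtualBatchSize rows and normalizes every virtual batch with its own mean and variance. The sketch below shows that idea on a plain double[][]; it omits the learnable scale and shift as well as the running statistics, and how DJL chunks the batch internally is an assumption. GhostBatchNorm is a hypothetical name.

```java
/** Sketch of ghost batch normalization (no learnable scale/shift, no running statistics). */
final class GhostBatchNorm {
    private GhostBatchNorm() {}

    /**
     * Normalizes each feature column within consecutive chunks of virtualBatchSize rows.
     * The last chunk is smaller when the batch size is not a multiple of virtualBatchSize.
     */
    static double[][] normalize(double[][] batch, int virtualBatchSize, double eps) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        int rows = batch.length;
        int cols = rows == 0 ? 0 : batch[0].length;
        double[][] out = new double[rows][cols];
        for (int start = 0; start < rows; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, rows);
            int n = end - start;
            for (int c = 0; c < cols; c++) {
                double mean = 0.0;
                for (int r = start; r < end; r++) {
                    mean += batch[r][c];
                }
                mean /= n;
                double var = 0.0;
                for (int r = start; r < end; r++) {
                    double diff = batch[r][c] - mean;
                    var += diff * diff;
                }
                var /= n; // biased (population) variance of this virtual batch
                double denom = Math.sqrt(var + eps);
                for (int r = start; r < end; r++) {
                    out[r][c] = (batch[r][c] - mean) / denom;
                }
            }
        }
        return out;
    }
}
```

With virtualBatchSize = 2, the rows {1, 3} and {10, 20} are normalized separately, so both chunks map to roughly {-1, 1} even though their scales differ.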
      • optBatchNormMomentum

        public TabNet.Builder optBatchNormMomentum(float batchNormMomentum)
        Sets the momentum for the batch normalization layers.
        Parameters:
        batchNormMomentum - the momentum for the batch normalization layers
        Returns:
        this Builder
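Batch normalization keeps running estimates of the mean and variance for inference and updates them as an exponential moving average. The sketch assumes the convention in which momentum weights the existing running value (running = momentum * running + (1 - momentum) * batch), as in MXNet; PyTorch's BatchNorm uses the opposite convention, so confirm which one applies before carrying a value across frameworks. RunningStat is a hypothetical name.

```java
/** Sketch of a batch-norm running-statistic update, assuming momentum weights the old value. */
final class RunningStat {
    private RunningStat() {}

    /** Returns momentum * running + (1 - momentum) * batchValue. */
    static double update(double running, double batchValue, double momentum) {
        if (momentum < 0.0 || momentum > 1.0) {
            throw new IllegalArgumentException("momentum must be in [0, 1]");
        }
        return momentum * running + (1.0 - momentum) * batchValue;
    }
}
```

Under this convention a momentum close to 1 makes the running statistics change slowly: starting from 0 with a batch mean of 10 and momentum 0.9, the running mean moves to about 1.0 after one update and about 1.9 after two.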
      • buildAttentionTransformer

        public ai.djl.nn.Block buildAttentionTransformer(int units)
        Builds an attentionTransformer with the given number of units, for testing purposes.
        Parameters:
        units - the number of units to test with
        Returns:
        an attentionTransformer Block
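In the TabNet paper the attentive transformer turns its input into a feature-selection mask with sparsemax (Martins and Astudillo, 2016), which, unlike softmax, can assign exactly zero weight to some features. This page does not say whether buildAttentionTransformer uses sparsemax; the sketch below is the standard sparsemax algorithm on a plain double[], not DJL's implementation, and Sparsemax is a hypothetical name.

```java
import java.util.Arrays;

/** Sparsemax: Euclidean projection of z onto the probability simplex. */
final class Sparsemax {
    private Sparsemax() {}

    static double[] apply(double[] z) {
        if (z.length == 0) {
            throw new IllegalArgumentException("input must be non-empty");
        }
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending
        // Walk from the largest value down; the support condition holds for a prefix,
        // so the last rank that satisfies it fixes the threshold tau.
        double cumsum = 0.0;
        double tau = 0.0;
        for (int i = sorted.length - 1, rank = 1; i >= 0; i--, rank++) {
            cumsum += sorted[i];
            if (1.0 + rank * sorted[i] > cumsum) {
                tau = (cumsum - 1.0) / rank;
            }
        }
        double[] out = new double[z.length];
        for (int j = 0; j < z.length; j++) {
            out[j] = Math.max(0.0, z[j] - tau); // entries at or below tau become exactly 0
        }
        return out;
    }
}
```

For example, sparsemax maps {1, 2, 3} to {0, 0, 1}, whereas softmax would give every entry a nonzero weight.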
      • build

        public ai.djl.nn.Block build()
        Builds a TabNet block from the parameters set on this Builder.
        Returns:
        a TabNet Block
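To illustrate how the setters above chain together, here is a hypothetical, dependency-free stand-in that mirrors their names. The naming on this page suggests that set* methods configure required values and opt* methods optional ones; the stand-in enforces that by rejecting a build without inputDim and outDim. Its build() returns a plain Config rather than an ai.djl.nn.Block, and its optional defaults are placeholders, not DJL's defaults.

```java
/**
 * Hypothetical stand-in for TabNet.Builder that mirrors its setter names.
 * NOT the DJL class: build() returns a plain Config instead of a Block,
 * and the optional defaults below are placeholders, not DJL's defaults.
 */
final class TabNetConfigBuilder {
    /** Immutable snapshot of the configured hyperparameters. */
    static final class Config {
        final int inputDim;
        final int outDim;
        final int numD;
        final int numA;
        final int numShared;
        final int numIndependent;
        final int numSteps;
        final int virtualBatchSize;
        final float batchNormMomentum;

        private Config(TabNetConfigBuilder b) {
            inputDim = b.inputDim;
            outDim = b.outDim;
            numD = b.numD;
            numA = b.numA;
            numShared = b.numShared;
            numIndependent = b.numIndependent;
            numSteps = b.numSteps;
            virtualBatchSize = b.virtualBatchSize;
            batchNormMomentum = b.batchNormMomentum;
        }
    }

    private int inputDim = -1; // required, set via setInputDim
    private int outDim = -1;   // required, set via setOutDim
    private int numD = 8;      // optional values from here on: placeholder defaults
    private int numA = 8;
    private int numShared = 2;
    private int numIndependent = 2;
    private int numSteps = 3;
    private int virtualBatchSize = 128;
    private float batchNormMomentum = 0.9f;

    TabNetConfigBuilder setInputDim(int inputDim) { this.inputDim = inputDim; return this; }
    TabNetConfigBuilder setOutDim(int outDim) { this.outDim = outDim; return this; }
    TabNetConfigBuilder optNumD(int numD) { this.numD = numD; return this; }
    TabNetConfigBuilder optNumA(int numA) { this.numA = numA; return this; }
    TabNetConfigBuilder optNumShared(int numShared) { this.numShared = numShared; return this; }
    TabNetConfigBuilder optNumIndependent(int numIndependent) { this.numIndependent = numIndependent; return this; }
    TabNetConfigBuilder optNumSteps(int numSteps) { this.numSteps = numSteps; return this; }
    TabNetConfigBuilder optVirtualBatchSize(int virtualBatchSize) { this.virtualBatchSize = virtualBatchSize; return this; }
    TabNetConfigBuilder optBatchNormMomentum(float batchNormMomentum) { this.batchNormMomentum = batchNormMomentum; return this; }

    /** Validates the required settings and returns an immutable Config. */
    Config build() {
        if (inputDim <= 0 || outDim <= 0) {
            throw new IllegalStateException("setInputDim and setOutDim must be called with positive values");
        }
        return new Config(this);
    }
}
```

With DJL on the classpath, the equivalent chain using the signatures on this page would be new TabNet.Builder().setInputDim(10).setOutDim(2).optNumSteps(5).build(), which returns an ai.djl.nn.Block.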