Class TabNet.Builder

java.lang.Object
ai.djl.basicmodelzoo.tabular.TabNet.Builder
Enclosing class:
TabNet

public static class TabNet.Builder extends Object
The Builder to construct a TabNet object.
  • Constructor Details

    • Builder

      public Builder()
  • Method Details

    • setInputDim

      public TabNet.Builder setInputDim(int inputDim)
      Sets the input dimension of TabNet.
      Parameters:
      inputDim - the input dimension
      Returns:
      this Builder
    • setOutDim

      public TabNet.Builder setOutDim(int outDim)
      Sets the output dimension for TabNet.
      Parameters:
      outDim - the output dimension
      Returns:
      this Builder
    • optNumD

      public TabNet.Builder optNumD(int numD)
      Sets the number of dimensions used outside the attentionTransformer.
      Parameters:
      numD - the number of dimensions used outside the attentionTransformer
      Returns:
      this Builder
    • optNumA

      public TabNet.Builder optNumA(int numA)
      Sets the number of dimensions for the attentionTransformer.
      Parameters:
      numA - the number of dimensions for the attentionTransformer
      Returns:
      this Builder
    • optNumShared

      public TabNet.Builder optNumShared(int numShared)
      Sets the number of shared fullyConnected layers.
      Parameters:
      numShared - the number of shared fullyConnected layers
      Returns:
      this Builder
    • optNumIndependent

      public TabNet.Builder optNumIndependent(int numIndependent)
      Sets the number of independent fullyConnected layers.
      Parameters:
      numIndependent - the number of independent fullyConnected layers
      Returns:
      this Builder
    • optNumSteps

      public TabNet.Builder optNumSteps(int numSteps)
      Sets the number of decision steps for TabNet.
      Parameters:
      numSteps - the number of decision steps for TabNet
      Returns:
      this Builder
    • optVirtualBatchSize

      public TabNet.Builder optVirtualBatchSize(int virtualBatchSize)
      Sets the virtual batch size for ghost batch norm.
      Parameters:
      virtualBatchSize - the virtual batch size
      Returns:
      this Builder
    • optBatchNormMomentum

      public TabNet.Builder optBatchNormMomentum(float batchNormMomentum)
      Sets the momentum for the batchNorm layer.
      Parameters:
      batchNormMomentum - the momentum for the batchNorm layer
      Returns:
      this Builder
    • buildAttentionTransformer

      public ai.djl.nn.Block buildAttentionTransformer(int units)
      Builds an attentionTransformer with the given number of units, for testing.
      Parameters:
      units - the number of units
      Returns:
      an attentionTransformer Block
    • build

      public ai.djl.nn.Block build()
      Builds a TabNet block from the settings in this Builder.
      Returns:
      a TabNet Block
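The optNumD and optNumA settings are easiest to read against the TabNet paper. There, each step's feature transformer emits one vector that is split into a decision part of width N_d and an attention part of width N_a, and the attention part feeds the next attentive transformer. The sketch below assumes numD and numA play those two roles. The class DecisionAttentionSplit and its method are hypothetical, not part of DJL.

```java
import java.util.Arrays;

// Hypothetical illustration, not DJL code: splits one feature-transformer
// output row into a decision part (numD values) and an attention part
// (numA values), following the TabNet paper's layout.
final class DecisionAttentionSplit {

    // Returns {decisionPart, attentionPart}.
    static float[][] split(float[] transformed, int numD, int numA) {
        if (transformed.length != numD + numA) {
            throw new IllegalArgumentException(
                    "expected " + (numD + numA) + " values, got " + transformed.length);
        }
        float[] decision = Arrays.copyOfRange(transformed, 0, numD);
        float[] attention = Arrays.copyOfRange(transformed, numD, numD + numA);
        return new float[][] {decision, attention};
    }

    public static void main(String[] args) {
        float[][] parts = split(new float[] {1f, 2f, 3f, 4f, 5f}, 3, 2);
        System.out.println(Arrays.toString(parts[0])); // [1.0, 2.0, 3.0]
        System.out.println(Arrays.toString(parts[1])); // [4.0, 5.0]
    }
}
```

Under this reading, the layers that produce the split must output numD + numA values per row.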
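The optNumShared and optNumIndependent settings distinguish fully connected layers whose weights are reused across decision steps from layers created separately for each step. The hypothetical SharedLayerSketch class below models only that identity relationship: shared layers are the same objects in every step, and independent layers are new for each step. It assumes shared layers run before independent ones, as in the TabNet paper, and ignores batch norm, GLU activations, and any initial transformer the real implementation may add.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical illustration, not DJL code: models which fully connected
// layers each decision step uses. Shared layers are the same objects (and so
// the same weights) in every step; independent layers are new per step.
final class SharedLayerSketch {

    // Stand-in for a fully connected layer; only its object identity matters.
    static final class Layer {
        final String name;

        Layer(String name) {
            this.name = name;
        }
    }

    static List<List<Layer>> buildSteps(int numSteps, int numShared, int numIndependent) {
        List<Layer> shared = new ArrayList<>();
        for (int i = 0; i < numShared; i++) {
            shared.add(new Layer("shared-" + i));
        }
        List<List<Layer>> steps = new ArrayList<>();
        for (int s = 0; s < numSteps; s++) {
            // Shared layers come first and are reused as-is.
            List<Layer> step = new ArrayList<>(shared);
            for (int i = 0; i < numIndependent; i++) {
                step.add(new Layer("step" + s + "-independent-" + i));
            }
            steps.add(step);
        }
        return steps;
    }

    // Layer does not override equals, so HashSet compares by identity.
    static int distinctLayerCount(List<List<Layer>> steps) {
        Set<Layer> seen = new HashSet<>();
        for (List<Layer> step : steps) {
            seen.addAll(step);
        }
        return seen.size();
    }

    public static void main(String[] args) {
        // 2 shared + 3 steps * 2 independent = 8 distinct layers
        System.out.println(distinctLayerCount(buildSteps(3, 2, 2))); // 8
    }
}
```

In this model, the number of distinct layers is numShared + numSteps * numIndependent, so raising numShared adds fewer parameters than raising numIndependent.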
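For optNumSteps, the TabNet paper combines the per-step decision outputs by applying ReLU to each step's decision vector and summing across steps before a final linear layer. The hypothetical DecisionStepAggregation class below sketches only that summation. The final layer that maps to outDim is omitted, and it is an assumption that DJL aggregates steps the same way.

```java
import java.util.Arrays;

// Illustration of the TabNet paper's output aggregation, not DJL code:
// the decision output is the sum over steps of ReLU(d[i]).
final class DecisionStepAggregation {

    // perStep[i] holds the numD decision values produced by step i.
    static float[] aggregate(float[][] perStep) {
        if (perStep.length == 0) {
            throw new IllegalArgumentException("need at least one decision step");
        }
        int numD = perStep[0].length;
        float[] out = new float[numD];
        for (float[] step : perStep) {
            if (step.length != numD) {
                throw new IllegalArgumentException("all steps must have numD values");
            }
            for (int j = 0; j < numD; j++) {
                out[j] += Math.max(0f, step[j]); // ReLU, then accumulate
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] steps = {{1f, -2f}, {0.5f, 3f}};
        System.out.println(Arrays.toString(aggregate(steps))); // [1.5, 3.0]
    }
}
```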
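Ghost batch norm, which optVirtualBatchSize configures, normalizes each chunk of virtualBatchSize rows with that chunk's own mean and variance instead of the statistics of the full batch. The hypothetical GhostBatchNormSketch below applies this to a single feature column. It omits the learnable scale and shift, and it assumes a shorter final chunk is normalized on its own; the eps term keeps a one-row chunk from dividing by zero.

```java
import java.util.Arrays;

// Hypothetical illustration, not DJL code: ghost batch norm on one feature
// column. Each chunk of virtualBatchSize rows is normalized with its own
// mean and (biased) variance; learnable scale and shift are omitted.
final class GhostBatchNormSketch {

    static double[] normalize(double[] batch, int virtualBatchSize, double eps) {
        if (virtualBatchSize <= 0) {
            throw new IllegalArgumentException("virtualBatchSize must be positive");
        }
        double[] out = new double[batch.length];
        // Assumption: a shorter final chunk is normalized on its own.
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch.length);
            int n = end - start;
            double mean = 0;
            for (int i = start; i < end; i++) {
                mean += batch[i];
            }
            mean /= n;
            double var = 0;
            for (int i = start; i < end; i++) {
                var += (batch[i] - mean) * (batch[i] - mean);
            }
            var /= n;
            double denom = Math.sqrt(var + eps); // eps > 0 avoids 0/0 on constant chunks
            for (int i = start; i < end; i++) {
                out[i] = (batch[i] - mean) / denom;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] batch = {1, 3, 10, 20};
        // Two virtual batches: {1, 3} and {10, 20}
        System.out.println(Arrays.toString(normalize(batch, 2, 0.0))); // [-1.0, 1.0, -1.0, 1.0]
    }
}
```

With virtualBatchSize equal to the batch length, the same method reduces to ordinary batch normalization over the whole batch, which gives different values for this input.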
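The batchNormMomentum value controls how quickly the running statistics used at inference time follow the per-batch statistics seen in training. The sketch below assumes the MXNet-style convention, in which momentum weights the previous running value (PyTorch's BatchNorm uses the opposite convention), so larger values make the running mean adapt more slowly. RunningMeanSketch is a hypothetical name.

```java
// Hypothetical illustration, not DJL code.
// Assumption: MXNet-style momentum, where momentum weights the previous
// running value: running = momentum * running + (1 - momentum) * batchMean.
final class RunningMeanSketch {

    static double update(double running, double batchMean, double momentum) {
        return momentum * running + (1 - momentum) * batchMean;
    }

    // Applies update() once per batch mean, starting from an initial value.
    static double track(double initial, double[] batchMeans, double momentum) {
        double running = initial;
        for (double m : batchMeans) {
            running = update(running, m, momentum);
        }
        return running;
    }

    public static void main(String[] args) {
        double[] means = {10, 10, 10};
        // Higher momentum -> running mean moves toward 10 more slowly.
        System.out.println(track(0, means, 0.5)); // 8.75
        System.out.println(track(0, means, 0.9));
    }
}
```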
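In the TabNet paper, the attentive transformer produces a feature mask by applying sparsemax (Martins and Astudillo, 2016) to the element-wise product of a prior scale and a learned projection of the previous step's attention features. The hypothetical SparsemaxSketch below implements sparsemax, the Euclidean projection onto the probability simplex, together with that prior-scaled mask. The learned projection and batch norm are left out, and it is an assumption that buildAttentionTransformer follows this formulation.

```java
import java.util.Arrays;

// Illustration of sparsemax as used by the TabNet paper's attentive
// transformer. Not DJL code; whether buildAttentionTransformer uses exactly
// this formulation is an assumption.
final class SparsemaxSketch {

    // Euclidean projection of z onto the probability simplex.
    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; walked from the end for descending order
        int n = sorted.length;
        double cumulative = 0;
        double tau = 0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k]; // k-th largest value
            cumulative += zk;
            if (1 + k * zk > cumulative) {
                tau = (cumulative - 1) / k;
            } else {
                break; // the support condition holds only for a prefix of k values
            }
        }
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(0, z[i] - tau);
        }
        return p;
    }

    // Mask as in the paper: sparsemax(prior * h), element-wise.
    static double[] attentionMask(double[] prior, double[] h) {
        double[] scaled = new double[h.length];
        for (int i = 0; i < h.length; i++) {
            scaled[i] = prior[i] * h[i];
        }
        return sparsemax(scaled);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sparsemax(new double[] {1, 2, 3}))); // [0.0, 0.0, 1.0]
        // Down-weighting a feature through the prior lowers its mask value;
        // here the third feature's mask drops to 0.
        System.out.println(Arrays.toString(
                attentionMask(new double[] {1, 1, 0}, new double[] {1.0, 1.2, 5.0})));
    }
}
```

Unlike softmax, sparsemax can assign exact zeros, as in the first call, but it does not always do so: inputs that are small and close together, such as {0.1, 0.1, 0.0}, all receive nonzero weight.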