Package ai.djl.basicmodelzoo.tabular
Class TabNet.Builder
java.lang.Object
  ai.djl.basicmodelzoo.tabular.TabNet.Builder
Constructor Summary
Builder()
Method Summary
Modifier and Type  Method                                         Description
ai.djl.nn.Block    build()                                        Builds a TabNet with the given Builder.
ai.djl.nn.Block    buildAttentionTransformer(int units)           Builds an attentionTransformer with the given parameter, for testing.
TabNet.Builder     optBatchNormMomentum(float batchNormMomentum)  Sets the momentum for the batchNorm layer.
TabNet.Builder     optNumA(int numA)                              Sets the number of dimensions for the attentionTransformer.
TabNet.Builder     optNumD(int numD)                              Sets the number of dimensions used outside the attentionTransformer.
TabNet.Builder     optNumIndependent(int numIndependent)          Sets the number of independent fullyConnected layers.
TabNet.Builder     optNumShared(int numShared)                    Sets the number of shared fullyConnected layers.
TabNet.Builder     optNumSteps(int numSteps)                      Sets the number of decision steps for TabNet.
TabNet.Builder     optVirtualBatchSize(int virtualBatchSize)      Sets the virtual batch size for ghost batch norm.
TabNet.Builder     setInputDim(int inputDim)                      Sets the input dimension of TabNet.
TabNet.Builder     setOutDim(int outDim)                          Sets the output dimension of TabNet.
Method Detail

setInputDim
public TabNet.Builder setInputDim(int inputDim)
Sets the input dimension of TabNet.
Parameters:
inputDim - the input dimension
Returns:
this Builder

setOutDim
public TabNet.Builder setOutDim(int outDim)
Sets the output dimension of TabNet.
Parameters:
outDim - the output dimension
Returns:
this Builder

optNumD
public TabNet.Builder optNumD(int numD)
Sets the number of dimensions used outside the attentionTransformer.
Parameters:
numD - the number of dimensions used outside the attentionTransformer
Returns:
this Builder

optNumA
public TabNet.Builder optNumA(int numA)
Sets the number of dimensions for the attentionTransformer.
Parameters:
numA - the number of dimensions for the attentionTransformer
Returns:
this Builder
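The two options above only say that numA belongs to the attentionTransformer and numD to everything else. In the TabNet paper, each step's feature transformer emits one vector that is split into a decision part (width N_d) and an attention part (width N_a). Assuming numD and numA play those roles, the split can be sketched as follows; the FeatureSplit class and its method are hypothetical illustrations, not part of DJL.

```java
import java.util.Arrays;

/** Hypothetical sketch: splitting one step's output into decision and attention parts. */
public final class FeatureSplit {

    /** Holds the two parts of a single step's output. */
    public static final class Parts {
        public final double[] decision;
        public final double[] attention;

        Parts(double[] decision, double[] attention) {
            this.decision = decision;
            this.attention = attention;
        }
    }

    /** Takes the first numD values as the decision part and the next numA values as the attention part. */
    public static Parts split(double[] output, int numD, int numA) {
        if (output.length != numD + numA) {
            throw new IllegalArgumentException(
                    "expected length " + (numD + numA) + " but got " + output.length);
        }
        double[] decision = Arrays.copyOfRange(output, 0, numD);
        double[] attention = Arrays.copyOfRange(output, numD, numD + numA);
        return new Parts(decision, attention);
    }

    public static void main(String[] args) {
        Parts p = split(new double[] {0.1, 0.2, 0.3, 0.4, 0.5}, 3, 2);
        System.out.println(Arrays.toString(p.decision));  // [0.1, 0.2, 0.3]
        System.out.println(Arrays.toString(p.attention)); // [0.4, 0.5]
    }
}
```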

optNumShared
public TabNet.Builder optNumShared(int numShared)
Sets the number of shared fullyConnected layers.
Parameters:
numShared - the number of shared fullyConnected layers
Returns:
this Builder

optNumIndependent
public TabNet.Builder optNumIndependent(int numIndependent)
Sets the number of independent fullyConnected layers.
Parameters:
numIndependent - the number of independent fullyConnected layers
Returns:
this Builder
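The difference between shared and independent layers is easiest to see as object identity. In the TabNet paper, shared layers are reused by every decision step, while independent layers are created anew for each step. The sketch below models that with a hypothetical Layer class and assumes one feature transformer per decision step; it does not describe how DJL wires its blocks internally.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical sketch of shared versus step-specific (independent) layers. */
public final class FeatureTransformers {

    /** Stand-in for a fully connected layer; only its identity matters here. */
    public static final class Layer {
        private final String name;

        Layer(String name) {
            this.name = name;
        }

        @Override
        public String toString() {
            return name;
        }
    }

    /** Returns one layer list per decision step. */
    public static List<List<Layer>> build(int numSteps, int numShared, int numIndependent) {
        // Shared layers are created once and referenced by every step.
        List<Layer> shared = new ArrayList<>();
        for (int i = 0; i < numShared; i++) {
            shared.add(new Layer("shared" + i));
        }
        List<List<Layer>> steps = new ArrayList<>();
        for (int s = 0; s < numSteps; s++) {
            List<Layer> layers = new ArrayList<>(shared);
            // Independent layers are created fresh for each step.
            for (int i = 0; i < numIndependent; i++) {
                layers.add(new Layer("step" + s + "_indep" + i));
            }
            steps.add(layers);
        }
        return steps;
    }

    public static void main(String[] args) {
        List<List<Layer>> steps = build(3, 2, 2);
        System.out.println(steps.get(0)); // [shared0, shared1, step0_indep0, step0_indep1]
        System.out.println(steps.get(0).get(0) == steps.get(1).get(0)); // true
    }
}
```

With this layout, the number of distinct layers is numShared + numSteps * numIndependent, so raising numShared adds capacity that every step shares.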

optNumSteps
public TabNet.Builder optNumSteps(int numSteps)
Sets the number of decision steps for TabNet.
Parameters:
numSteps - the number of decision steps for TabNet
Returns:
this Builder
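In the TabNet paper, the outputs of the decision steps are combined by applying ReLU to each step's decision vector and summing across steps; the sum is then mapped to the output dimension. A minimal sketch of that aggregation, assuming each step yields a decision vector of the same length (numD). The class name is hypothetical and this is not DJL's code.

```java
import java.util.Arrays;

/** Hypothetical sketch: aggregating decision outputs over numSteps steps. */
public final class DecisionAggregation {

    /** Sums ReLU(d_i) element-wise, where stepOutputs[i] is step i's decision vector. */
    public static double[] aggregate(double[][] stepOutputs) {
        int numD = stepOutputs[0].length;
        double[] sum = new double[numD];
        for (double[] d : stepOutputs) {
            for (int j = 0; j < numD; j++) {
                sum[j] += Math.max(0.0, d[j]); // ReLU, then accumulate
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[][] steps = {{1.0, -2.0}, {0.5, 3.0}, {-1.0, 1.0}};
        System.out.println(Arrays.toString(aggregate(steps))); // [1.5, 4.0]
    }
}
```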

optVirtualBatchSize
public TabNet.Builder optVirtualBatchSize(int virtualBatchSize)
Sets the virtual batch size for ghost batch norm.
Parameters:
virtualBatchSize - the virtual batch size
Returns:
this Builder
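Ghost batch normalization splits each training batch into virtual batches of virtualBatchSize rows and normalizes each chunk with its own mean and variance. The sketch below shows the idea for a single feature column, without learned scale and shift and without running statistics. The class name and the epsilon value are assumptions, and DJL's implementation may differ, for example in how a shorter final chunk is handled.

```java
import java.util.Arrays;

/** Hypothetical sketch of ghost batch normalization for one feature column. */
public final class GhostBatchNorm {
    private static final double EPS = 1e-5; // assumed numerical-stability constant

    /** Normalizes x in chunks of virtualBatchSize; the last chunk may be shorter. */
    public static double[] normalize(double[] x, int virtualBatchSize) {
        double[] out = new double[x.length];
        for (int start = 0; start < x.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, x.length);
            int n = end - start;
            // Mean of this virtual batch only.
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += x[i];
            }
            mean /= n;
            // Biased variance of this virtual batch only.
            double var = 0.0;
            for (int i = start; i < end; i++) {
                var += (x[i] - mean) * (x[i] - mean);
            }
            var /= n;
            for (int i = start; i < end; i++) {
                out[i] = (x[i] - mean) / Math.sqrt(var + EPS);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Chunks [1, 3], [10, 30] and [100] are normalized independently.
        System.out.println(Arrays.toString(normalize(new double[] {1, 3, 10, 30, 100}, 2)));
    }
}
```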

optBatchNormMomentum
public TabNet.Builder optBatchNormMomentum(float batchNormMomentum)
Sets the momentum for the batchNorm layer.
Parameters:
batchNormMomentum - the momentum for the batchNorm layer
Returns:
this Builder
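A batch norm layer keeps running estimates of the mean and variance for inference, and the momentum controls how fast those estimates move. This sketch assumes the MXNet-style convention running = momentum * running + (1 - momentum) * batchStatistic, under which a momentum close to 1 makes the estimate change slowly; the RunningMean class is hypothetical, so check the DJL BatchNorm documentation for the convention actually used.

```java
/** Hypothetical sketch of a momentum-based running-mean update (MXNet-style convention assumed). */
public final class RunningMean {
    private final float momentum;
    private float value;

    public RunningMean(float momentum, float initial) {
        this.momentum = momentum;
        this.value = initial;
    }

    /** Blends the current batch mean into the running estimate and returns the new estimate. */
    public float update(float batchMean) {
        value = momentum * value + (1f - momentum) * batchMean;
        return value;
    }

    public static void main(String[] args) {
        RunningMean rm = new RunningMean(0.9f, 0f);
        for (int step = 1; step <= 3; step++) {
            // With the batch mean fixed at 10, the estimate moves toward 10 by 10% of the gap per step.
            System.out.println("step " + step + ": " + rm.update(10f));
        }
    }
}
```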

buildAttentionTransformer
public ai.djl.nn.Block buildAttentionTransformer(int units)
Builds an attentionTransformer with the given parameter, for testing.
Parameters:
units - the number of units for the test
Returns:
an attentionTransformer Block
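The TabNet paper normalizes the attentive transformer's output with sparsemax, which, unlike softmax, can give some features exactly zero weight. This page does not say which normalization buildAttentionTransformer uses, so treat the following as a sketch of the sparsemax technique itself, not of DJL's block; the Sparsemax class is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of sparsemax: Euclidean projection of z onto the probability simplex. */
public final class Sparsemax {

    public static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending order
        int n = sorted.length;
        // Visit values from largest to smallest to find the support size k.
        double cumSum = 0.0;
        double supportSum = 0.0;
        int k = 0;
        for (int i = 0; i < n; i++) {
            double zi = sorted[n - 1 - i]; // (i + 1)-th largest value
            cumSum += zi;
            if (1.0 + (i + 1) * zi > cumSum) {
                k = i + 1;
                supportSum = cumSum;
            }
        }
        // Threshold tau makes the clipped outputs sum to 1.
        double tau = (supportSum - 1.0) / k;
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sparsemax(new double[] {2.0, 1.0, -1.0}))); // [1.0, 0.0, 0.0]
        System.out.println(Arrays.toString(sparsemax(new double[] {0.5, 0.5})));       // [0.5, 0.5]
    }
}
```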

build
public ai.djl.nn.Block build()
Builds a TabNet with the given Builder.
Returns:
a TabNet Block
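A minimal usage sketch, assuming the DJL basicmodelzoo artifact and its API dependency are on the classpath. It calls only the constructor and methods listed on this page; by DJL's naming convention, set* methods supply required values and opt* methods override defaults. The dimensions and option values are illustrative placeholders, not recommended settings, and the TabNetExample class is hypothetical.

```java
import ai.djl.basicmodelzoo.tabular.TabNet;
import ai.djl.nn.Block;

public final class TabNetExample {

    /** Builds a TabNet block for a hypothetical dataset with 14 features and 2 outputs. */
    public static Block buildTabNet() {
        return new TabNet.Builder()
                .setInputDim(14)            // number of input features (illustrative)
                .setOutDim(2)               // output dimension (illustrative)
                .optNumD(8)                 // dimensions outside the attentionTransformer
                .optNumA(8)                 // dimensions for the attentionTransformer
                .optNumShared(2)            // shared fullyConnected layers
                .optNumIndependent(2)       // independent fullyConnected layers
                .optNumSteps(3)             // decision steps
                .optVirtualBatchSize(128)   // virtual batch size for ghost batch norm
                .optBatchNormMomentum(0.9f) // momentum for the batchNorm layers
                .build();
    }

    public static void main(String[] args) {
        Block block = buildTabNet();
        System.out.println(block.getClass().getSimpleName());
    }
}
```

The returned Block is not yet initialized; like other DJL blocks, its parameters still have to be initialized (for example through a training setup) before it can run a forward pass.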