Package ai.djl.basicmodelzoo.tabular
Class TabNet.Builder
java.lang.Object
ai.djl.basicmodelzoo.tabular.TabNet.Builder
- Enclosing class:
- TabNet
The Builder to construct a TabNet object.
Constructor Summary
Builder()
Method Summary
Modifier and Type | Method | Description
ai.djl.nn.Block | build() | Builds a TabNet with the given Builder.
ai.djl.nn.Block | buildAttentionTransformer(int units) | Builds an attentionTransformer with the given parameter, for testing.
TabNet.Builder | optBatchNormMomentum(float batchNormMomentum) | Sets the momentum for the batchNorm layer.
TabNet.Builder | optNumA(int numA) | Sets the number of dimensions for the attentionTransformer.
TabNet.Builder | optNumD(int numD) | Sets the number of dimensions for all layers except the attentionTransformer.
TabNet.Builder | optNumIndependent(int numIndependent) | Sets the number of independent fullyConnected layers.
TabNet.Builder | optNumShared(int numShared) | Sets the number of shared fullyConnected layers.
TabNet.Builder | optNumSteps(int numSteps) | Sets the number of decision steps for TabNet.
TabNet.Builder | optVirtualBatchSize(int virtualBatchSize) | Sets the virtual batch size for ghost batch norm.
TabNet.Builder | setInputDim(int inputDim) | Sets the input dimension of TabNet.
TabNet.Builder | setOutDim(int outDim) | Sets the output dimension of TabNet.
Constructor Details
Builder
public Builder()
Method Details
setInputDim
Sets the input dimension of TabNet.
- Parameters:
inputDim - the input dimension
- Returns:
this Builder
setOutDim
Sets the output dimension of TabNet.
- Parameters:
outDim - the output dimension
- Returns:
this Builder
optNumD
Sets the number of dimensions for all layers except the attentionTransformer.
- Parameters:
numD - the number of dimensions for all layers except the attentionTransformer
- Returns:
this Builder
optNumA
Sets the number of dimensions for the attentionTransformer.
- Parameters:
numA - the number of dimensions for the attentionTransformer
- Returns:
this Builder
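In the TabNet paper, the output of each step's feature transformer is split into a decision part of width N_d and an attention part of width N_a. numD and numA presumably play those roles here, although this page does not say so explicitly. The following is a plain-Java sketch of that split with made-up values; `SplitSketch` and `split` are hypothetical names, not DJL APIs.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL code): split one row of width numD + numA into a
 * decision part d (first numD values) and an attention part a (last numA values).
 */
class SplitSketch {

    static double[][] split(double[] features, int numD, int numA) {
        if (features.length != numD + numA) {
            throw new IllegalArgumentException(
                    "expected width " + (numD + numA) + " but got " + features.length);
        }
        double[] d = Arrays.copyOfRange(features, 0, numD);           // decision part
        double[] a = Arrays.copyOfRange(features, numD, numD + numA); // attention part
        return new double[][] {d, a};
    }

    public static void main(String[] args) {
        double[][] parts = split(new double[] {1, 2, 3, 4, 5}, 3, 2);
        System.out.println("d = " + Arrays.toString(parts[0])); // d = [1.0, 2.0, 3.0]
        System.out.println("a = " + Arrays.toString(parts[1])); // a = [4.0, 5.0]
    }
}
```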
optNumIndependent
Sets the number of independent fullyConnected layers.
- Parameters:
numIndependent - the number of independent fullyConnected layers
- Returns:
this Builder
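The TabNet paper describes a feature transformer whose first fully connected layers are shared across all decision steps, while the remaining layers belong to a single step. Assuming optNumShared and optNumIndependent control those two counts, a minimal plain-Java sketch of the sharing looks like this; `FcLayer` and `buildSteps` are hypothetical placeholders, not DJL classes, and only object identity matters.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch (not DJL code): shared layers are created once and reused by
 * every decision step; independent layers are created anew for each step.
 */
class SharingSketch {

    /** Placeholder for a fully connected layer. */
    static final class FcLayer {
        final String name;

        FcLayer(String name) {
            this.name = name;
        }
    }

    static List<List<FcLayer>> buildSteps(int numSteps, int numShared, int numIndependent) {
        List<FcLayer> shared = new ArrayList<>();
        for (int i = 0; i < numShared; i++) {
            shared.add(new FcLayer("shared" + i));
        }
        List<List<FcLayer>> steps = new ArrayList<>();
        for (int s = 0; s < numSteps; s++) {
            List<FcLayer> layers = new ArrayList<>(shared); // same instances in every step
            for (int i = 0; i < numIndependent; i++) {
                layers.add(new FcLayer("step" + s + "_independent" + i)); // fresh per step
            }
            steps.add(layers);
        }
        return steps;
    }

    public static void main(String[] args) {
        List<List<FcLayer>> steps = buildSteps(3, 2, 2);
        System.out.println(steps.get(0).get(0) == steps.get(1).get(0)); // true: shared layer
        System.out.println(steps.get(0).get(2) == steps.get(1).get(2)); // false: independent layer
    }
}
```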
optNumSteps
Sets the number of decision steps for TabNet.
- Parameters:
numSteps - the number of decision steps for TabNet
- Returns:
this Builder
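In the TabNet paper, the final decision embedding sums the ReLU-activated decision outputs of every step, d_out = Σ ReLU(d[i]) for i = 1..numSteps. The plain-Java sketch below shows that aggregation with made-up step outputs; `StepAggregation` is a hypothetical name and this is not DJL's implementation.

```java
import java.util.Arrays;

/** Hypothetical sketch (not DJL code): sum of ReLU(d[i]) over all decision steps. */
class StepAggregation {

    static double[] aggregate(double[][] stepOutputs) {
        int width = stepOutputs[0].length;
        double[] total = new double[width];
        for (double[] d : stepOutputs) {         // one row per decision step
            for (int j = 0; j < width; j++) {
                total[j] += Math.max(0.0, d[j]); // apply ReLU, then accumulate
            }
        }
        return total;
    }

    public static void main(String[] args) {
        double[][] steps = {
            {0.5, -1.0, 2.0},
            {-0.5, 1.5, 1.0},
            {1.0, 0.25, -3.0}
        };
        System.out.println(Arrays.toString(aggregate(steps))); // [1.5, 1.75, 3.0]
    }
}
```

Negative step outputs contribute nothing, so each step can only add evidence to the final embedding.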
optVirtualBatchSize
Sets the virtual batch size for ghost batch norm.
- Parameters:
virtualBatchSize - the virtual batch size
- Returns:
this Builder
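Ghost batch normalization, as used in the TabNet paper, computes batch-norm statistics over small virtual batches of virtualBatchSize rows instead of over the whole batch. The plain-Java sketch below shows only that chunking for a single feature column; it omits the learnable scale and shift and the running statistics, uses the hypothetical name `GhostBatchNormSketch`, and is not how DJL implements the layer.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL code): normalize each chunk of virtualBatchSize values
 * with that chunk's own mean and biased variance.
 */
class GhostBatchNormSketch {

    static double[] normalize(double[] batch, int virtualBatchSize, double eps) {
        double[] out = new double[batch.length];
        for (int start = 0; start < batch.length; start += virtualBatchSize) {
            int end = Math.min(start + virtualBatchSize, batch.length); // last chunk may be smaller
            int n = end - start;
            double mean = 0.0;
            for (int i = start; i < end; i++) {
                mean += batch[i];
            }
            mean /= n;
            double var = 0.0;
            for (int i = start; i < end; i++) {
                var += (batch[i] - mean) * (batch[i] - mean);
            }
            var /= n;
            for (int i = start; i < end; i++) {
                out[i] = (batch[i] - mean) / Math.sqrt(var + eps);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Two virtual batches: {1, 3} and {10, 30}
        System.out.println(Arrays.toString(normalize(new double[] {1, 3, 10, 30}, 2, 0.0)));
        // [-1.0, 1.0, -1.0, 1.0]
    }
}
```

Normalizing {1, 3, 10, 30} as a single batch would give different values; with a virtual batch size of 2, each pair is standardized on its own, so both pairs map to -1 and 1.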
optBatchNormMomentum
Sets the momentum for the batchNorm layer.
- Parameters:
batchNormMomentum - the momentum for the batchNorm layer
- Returns:
this Builder
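Batch-norm momentum controls how quickly the running statistics used at inference time follow each new batch. The sketch below assumes the convention in which momentum weights the previous running value (MXNet style, where a value near 1 means slow updates). PyTorch uses the opposite convention, and this page does not say which one DJL follows, so treat the formula as an assumption. `RunningMeanSketch` is a hypothetical name.

```java
/**
 * Hypothetical sketch (not DJL code). ASSUMED convention:
 * running = momentum * running + (1 - momentum) * batchMean.
 */
class RunningMeanSketch {

    static double update(double running, double batchMean, float momentum) {
        return momentum * running + (1.0 - momentum) * batchMean;
    }

    public static void main(String[] args) {
        double running = 0.0;
        for (int step = 0; step < 3; step++) {
            running = update(running, 1.0, 0.5f); // every batch has mean 1.0
            System.out.println(running);          // 0.5, then 0.75, then 0.875
        }
    }
}
```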
buildAttentionTransformer
public ai.djl.nn.Block buildAttentionTransformer(int units)
Builds an attentionTransformer with the given parameter, for testing.
- Parameters:
units - the number of test units
- Returns:
an attentionTransformer Block
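The attentive transformer in the TabNet paper normalizes its feature mask with sparsemax (Martins & Astudillo, 2016), which, unlike softmax, can assign exactly zero weight to some features. This page does not state whether the block returned by buildAttentionTransformer uses sparsemax; the following is a standalone plain-Java sketch of the sparsemax projection under a hypothetical class name, not DJL code.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DJL code) of sparsemax: find the threshold tau such that
 * p_i = max(z_i - tau, 0) sums to 1.
 */
class SparsemaxSketch {

    static double[] sparsemax(double[] z) {
        double[] sorted = z.clone();
        Arrays.sort(sorted); // ascending; read from the end for descending order
        int n = sorted.length;
        double cumulative = 0.0;
        double tau = 0.0;
        for (int k = 1; k <= n; k++) {
            double zk = sorted[n - k]; // k-th largest value
            cumulative += zk;
            // The values satisfying this condition form a prefix, so the last
            // tau assigned corresponds to the size of the support.
            if (1.0 + k * zk > cumulative) {
                tau = (cumulative - 1.0) / k;
            }
        }
        double[] p = new double[n];
        for (int i = 0; i < n; i++) {
            p[i] = Math.max(z[i] - tau, 0.0);
        }
        return p;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sparsemax(new double[] {0.8, 0.6, 0.0})));
        System.out.println(Arrays.toString(sparsemax(new double[] {1.0, 0.0, 0.0})));
    }
}
```

For {0.8, 0.6, 0.0} the weights come out close to {0.6, 0.4, 0.0}: the smallest input is dropped entirely, which is what makes the resulting feature masks sparse.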
build
public ai.djl.nn.Block build()
Builds a TabNet with the given Builder.
- Returns:
a TabNet Block