Class Dropout
- All Implemented Interfaces:
Block
The idea of dropout was proposed in 2014 to improve the performance of large networks that suffer from co-adaptation. In co-adaptation, some connections become stronger and are learned more, while others become weaker and lose their impact on the prediction, which results in overfitting. Dropout was also intended as a cheaper alternative to costly models such as large or ensemble networks. By randomly removing units, it creates different thinned network architectures and so simulates multiple networks within a single network, greatly reducing the computational cost.
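The thinned-network idea can be sketched in plain Java. This is a minimal illustration, not DJL's implementation. It works on a `float[]` rather than an `NDArray` and assumes the common "inverted dropout" variant, in which surviving units are scaled by `1 / (1 - rate)` during training so that no rescaling is needed at inference. The class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.Random;

/** Hypothetical sketch of inverted dropout on a plain float array (not DJL code). */
public class DropoutMaskDemo {

    /** Samples a keep-mask: true means the unit survives in this thinned network. */
    public static boolean[] sampleMask(int units, float rate, Random rng) {
        boolean[] keep = new boolean[units];
        for (int i = 0; i < units; i++) {
            // nextFloat() is uniform in [0, 1), so each unit is dropped with probability `rate`.
            keep[i] = rng.nextFloat() >= rate;
        }
        return keep;
    }

    /** Zeroes dropped units and scales survivors by 1 / (1 - rate). */
    public static float[] apply(float[] input, boolean[] keep, float rate) {
        if (rate < 0f || rate >= 1f) {
            throw new IllegalArgumentException("rate must be in [0, 1): " + rate);
        }
        float scale = 1f / (1f - rate);
        float[] out = new float[input.length];
        for (int i = 0; i < input.length; i++) {
            out[i] = keep[i] ? input[i] * scale : 0f;
        }
        return out;
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        float[] activations = {1f, 2f, 3f, 4f, 5f, 6f};
        float rate = 0.5f;
        // Each step samples a new mask, i.e. trains a different thinned network.
        for (int step = 0; step < 3; step++) {
            boolean[] keep = sampleMask(activations.length, rate, rng);
            System.out.println(Arrays.toString(apply(activations, keep, rate)));
        }
    }
}
```

Because a fresh mask is sampled on every call, each training step effectively updates a different thinned network that shares its weights with all the others.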
Dropout is recommended when trying to improve a network that overfits, or when a large dataset is available. It is still widely used in publications because it improves generalization. However, dropout may not prevent overfitting when the dataset is small or highly variable. Dropout layers have also been reported to increase training time by 2-3 times: a different simulated network is trained on each iteration, which makes the parameter updates noisy.
-
Nested Class Summary
Nested Classes
Modifier and Type | Class
static final class | Dropout.Builder
-
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Method Summary
Modifier and Type | Method | Description
static Dropout.Builder | builder() | Creates a builder to build a Dropout.
static NDList | dropout(NDArray input) | Applies Dropout to the input.
static NDList | dropout(NDArray input, float rate) | Applies Dropout to the input.
static NDList | dropout(NDArray input, float rate, boolean training) | Applies Dropout to the input.
protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params) | A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] | getOutputShapes(Shape[] inputShapes) | Returns the expected output shapes of the block for the specified input shapes.
void | loadMetadata(byte loadVersion, DataInputStream is) | Override this to load additional metadata with the parameter values.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
-
Method Details
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
-
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
- Parameters:
inputShapes - the shapes of the inputs
- Returns:
the expected output shapes of the block
-
loadMetadata
public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
Override this to load additional metadata with the parameter values.
If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. This default implementation checks whether the version number fits; if it does not, it throws a MalformedModelException. After that, it restores the input shapes.
- Overrides:
loadMetadata in class AbstractBaseBlock
- Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
- Throws:
IOException - loading failed
MalformedModelException - the data can be loaded but has the wrong format
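The version check described above can be illustrated with a small, self-contained sketch built on `java.io` streams. Everything here is hypothetical: the class `MetadataSketch`, the version constant, and the layout (a version byte, then the rank and dimensions of one input shape) are assumptions. A plain `IOException` stands in for `MalformedModelException` so the example runs without DJL. As in the signature above, the caller reads the version byte and passes it to `load` separately from the stream.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Hypothetical sketch of versioned metadata saving/loading (not DJL code). */
public class MetadataSketch {
    /** Hypothetical current format version. */
    public static final byte VERSION = 1;

    /** Writes the version byte, then the rank and dimensions of one input shape. */
    public static void save(long[] inputShape, DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(inputShape.length);
        for (long dim : inputShape) {
            os.writeLong(dim);
        }
        os.flush();
    }

    /** Rejects unknown versions, then restores the input shape. */
    public static long[] load(byte loadVersion, DataInputStream is) throws IOException {
        if (loadVersion != VERSION) {
            // Stand-in for MalformedModelException so the sketch needs no DJL dependency.
            throw new IOException("Unsupported metadata version: " + loadVersion);
        }
        int rank = is.readInt();
        long[] shape = new long[rank];
        for (int i = 0; i < rank; i++) {
            shape[i] = is.readLong();
        }
        return shape;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        save(new long[] {32, 128}, new DataOutputStream(bytes));
        DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        byte version = is.readByte(); // the caller reads the version and passes it in
        System.out.println(Arrays.toString(load(version, is))); // prints [32, 128]
    }
}
```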
-
dropout
public static NDList dropout(NDArray input)
Applies Dropout to the input.
- Parameters:
input - the input to apply dropout to
- Returns:
the output
-
dropout
public static NDList dropout(NDArray input, float rate)
Applies Dropout to the input.
- Parameters:
input - the input to apply dropout to
rate - the fraction of the input units to drop
- Returns:
the output
-
dropout
public static NDList dropout(NDArray input, float rate, boolean training)
Applies Dropout to the input.
- Parameters:
input - the input to apply dropout to
rate - the fraction of the input units to drop
training - apply dropout if true
- Returns:
the output
-
builder
public static Dropout.Builder builder()
Creates a builder to build a Dropout.
- Returns:
a new builder
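The `builder()` factory follows the usual builder pattern: a static method returns a mutable builder, and a build step produces the configured block. The sketch below is hypothetical. `DropoutBlock`, `optRate`, `build` and the 0.5 default are illustrative names and values, not taken from this page. It only shows the shape of the pattern, including validating the rate when the block is built.

```java
/** Hypothetical sketch of a static builder() factory with a nested Builder (not DJL code). */
public final class DropoutBlock {
    private final float rate;

    private DropoutBlock(Builder builder) {
        this.rate = builder.rate;
    }

    /** Static factory, analogous in shape to Dropout.builder(). */
    public static Builder builder() {
        return new Builder();
    }

    public float getRate() {
        return rate;
    }

    public static final class Builder {
        private float rate = 0.5f; // assumed default for this sketch

        private Builder() {}

        /** Sets the fraction of units to drop; returns this for chaining. */
        public Builder optRate(float rate) {
            this.rate = rate;
            return this;
        }

        /** Validates the configuration and creates an immutable block. */
        public DropoutBlock build() {
            if (rate < 0f || rate >= 1f) {
                throw new IllegalArgumentException("rate must be in [0, 1): " + rate);
            }
            return new DropoutBlock(this);
        }
    }

    public static void main(String[] args) {
        DropoutBlock block = DropoutBlock.builder().optRate(0.3f).build();
        System.out.println(block.getRate());
    }
}
```

Validating in the build step keeps the block itself immutable and guarantees that every constructed instance has a usable rate.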
-