public class Dropout extends AbstractBlock
Dropout was proposed in 2014 to improve the performance of large networks that suffer from co-adaptation, where some connections grow stronger and are learned more while others weaken and lose their impact on the prediction, causing the network to overfit. It was also designed as a cheaper alternative to costly models such as very large or ensemble networks: randomly removing units creates different thinned network architectures, which simulates multiple networks within a single network and greatly reduces the computation cost.
Dropout is recommended when optimizing an overfitting network or when a large dataset is available. It is still commonly used in many publications because of its generalization capability. However, dropout may not prevent overfitting when the dataset is small or highly variable, and a dropout layer is reported to increase training time by 2-3 times: a different simulated network is trained in each iteration, which results in noisy parameter updates.
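The thinning described above can be illustrated with a small standalone sketch. It uses the common "inverted dropout" formulation on a plain float array: during training each unit is zeroed with probability `rate` and the survivors are scaled by `1 / (1 - rate)`, while outside training the input passes through unchanged. The class `DropoutSketch` and its method are hypothetical names for illustration only; this is not DJL's implementation, and whether a given engine scales in exactly this way is an assumption.

```java
import java.util.Arrays;
import java.util.Random;

public class DropoutSketch {

    /**
     * Inverted dropout on a flat array (illustrative only, not DJL code).
     *
     * @param input    the activations
     * @param rate     fraction of the units to drop, in [0, 1)
     * @param training apply dropout if true, otherwise return a copy of the input
     * @param rng      source of randomness for the drop mask
     */
    public static float[] dropout(float[] input, float rate, boolean training, Random rng) {
        if (rate < 0f || rate >= 1f) {
            throw new IllegalArgumentException("rate must be in [0, 1)");
        }
        float[] output = new float[input.length];
        if (!training || rate == 0f) {
            // Inference: no units are dropped.
            System.arraycopy(input, 0, output, 0, input.length);
            return output;
        }
        // Scale survivors so the expected value of each unit is unchanged.
        float scale = 1f / (1f - rate);
        for (int i = 0; i < input.length; i++) {
            boolean keep = rng.nextFloat() >= rate;
            output[i] = keep ? input[i] * scale : 0f;
        }
        return output;
    }

    public static void main(String[] args) {
        float[] x = {1f, 2f, 3f, 4f};
        Random rng = new Random(42);
        // Each call draws a new mask, i.e. a different thinned network.
        System.out.println(Arrays.toString(dropout(x, 0.5f, true, rng)));
        System.out.println(Arrays.toString(dropout(x, 0.5f, true, rng)));
        // With training == false the input is returned unchanged.
        System.out.println(Arrays.toString(dropout(x, 0.5f, false, rng)));
    }
}
```

Because a fresh mask is drawn on every call, successive training iterations see different thinned networks, which is also where the noisy parameter updates mentioned above come from.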
Modifier and Type | Class and Description |
---|---|
static class | Dropout.Builder |
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, version
Modifier and Type | Method and Description |
---|---|
static Dropout.Builder | builder() Creates a builder to build a Dropout. |
static NDList | dropout(NDArray input) Applies Dropout to the input. |
static NDList | dropout(NDArray input, float rate) Applies Dropout to the input. |
static NDList | dropout(NDArray input, float rate, boolean training) Applies Dropout to the input. |
protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization. |
Shape[] | getOutputShapes(Shape[] inputShapes) Returns the expected output shapes of the block for the specified input shapes. |
void | loadMetadata(byte version, java.io.DataInputStream is) Overwrite this to load additional metadata with the parameter values. |
java.lang.String | toString() |
Methods inherited from class AbstractBlock: addChildBlock, addParameter, beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block: forward, validateLayout
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
forwardInternal in class AbstractBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters

public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputShapes - the shapes of the inputs

public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. After that, it restores the input shapes.
loadMetadata in class AbstractBlock
Parameters:
version - the version used for loading this metadata
is - the input stream we are loading from
Throws:
java.io.IOException - loading failed
MalformedModelException - data can be loaded but has wrong format

public java.lang.String toString()
toString
in class AbstractBlock
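The version check described for loadMetadata can be pictured with a standalone sketch: reject an unexpected version first, then read the stored input shapes back from the stream. Everything here is hypothetical and for illustration only: the class `MetadataSketch`, the exception `BadFormatException` (a stand-in for MalformedModelException), the version constant, and the binary layout are assumptions, not DJL's actual format or code.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

public class MetadataSketch {

    /** Hypothetical format version expected by this sketch. */
    public static final byte VERSION = 2;

    /** Hypothetical stand-in for MalformedModelException. */
    public static class BadFormatException extends Exception {
        public BadFormatException(String message) {
            super(message);
        }
    }

    /** Writes the shapes as: count, then for each shape its rank and dimensions. */
    public static void saveShapes(DataOutputStream os, long[][] shapes) throws IOException {
        os.writeInt(shapes.length);
        for (long[] shape : shapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    /** Checks the version, then restores the shapes written by saveShapes. */
    public static long[][] loadMetadata(byte version, DataInputStream is)
            throws IOException, BadFormatException {
        if (version != VERSION) {
            throw new BadFormatException("Unsupported version: " + version);
        }
        int count = is.readInt();
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            int rank = is.readInt();
            shapes[i] = new long[rank];
            for (int j = 0; j < rank; j++) {
                shapes[i][j] = is.readLong();
            }
        }
        return shapes;
    }

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        saveShapes(new DataOutputStream(buffer), new long[][] {{32, 128}});
        DataInputStream in =
                new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));
        long[][] restored = loadMetadata(VERSION, in);
        System.out.println(Arrays.deepToString(restored)); // prints [[32, 128]]
    }
}
```

A subclass that writes extra fields when saving would read them back in the same order after the version check, which is why the two methods usually need to be overridden together.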
public static NDList dropout(NDArray input)
Applies Dropout to the input.
Parameters:
input - input to apply dropout

public static NDList dropout(NDArray input, float rate)
Applies Dropout to the input.
Parameters:
input - input to apply dropout
rate - fraction of the input units to drop

public static NDList dropout(NDArray input, float rate, boolean training)
Applies Dropout to the input.
Parameters:
input - input to apply dropout
rate - fraction of the input units to drop
training - apply dropout if true

public static Dropout.Builder builder()
Creates a builder to build a Dropout.