Class Deconvolution
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.nn.convolutional.Deconvolution
- All Implemented Interfaces:
Block
- Direct Known Subclasses:
Conv1dTranspose, Conv2dTranspose
public abstract class Deconvolution extends AbstractBlock
This block implements transposed convolution, also called fractionally-strided convolution (Dumoulin & Visin) or deconvolution (Long et al., 2015). The need for transposed convolutions generally arises from the desire to use a transformation going in the opposite direction of a normal convolution, i.e., from something that has the shape of the output of some convolution to something that has the shape of its input, while maintaining a connectivity pattern that is compatible with said convolution.
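A regular convolution with a kernel of length 2 and stride 1 maps a length-4 signal to length 3; the transposed direction maps length 3 back to length 4. The sketch below shows that direction in plain Java under simplifying assumptions (single channel, no padding, dilation 1). `TransposedConv1dSketch` is a hypothetical helper, not part of DJL. It implements the operation as a scatter-add: each input element adds a scaled copy of the kernel into the output.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of a single-channel 1-D transposed convolution
 * (no padding, dilation 1). Illustration only; not part of DJL.
 */
final class TransposedConv1dSketch {
    private TransposedConv1dSketch() {}

    static double[] transposedConv1d(double[] input, double[] kernel, int stride) {
        if (stride < 1) {
            throw new IllegalArgumentException("stride must be >= 1");
        }
        if (input.length == 0) {
            return new double[0];
        }
        // The output grows instead of shrinking: (n - 1) * stride + k.
        int outLen = (input.length - 1) * stride + kernel.length;
        double[] out = new double[outLen];
        // Scatter-add: input[i] adds a scaled copy of the kernel
        // into the output, starting at position i * stride.
        for (int i = 0; i < input.length; i++) {
            for (int k = 0; k < kernel.length; k++) {
                out[i * stride + k] += input[i] * kernel[k];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3};
        double[] w = {1, 1};
        System.out.println(Arrays.toString(transposedConv1d(x, w, 1))); // [1.0, 3.0, 5.0, 3.0]
        System.out.println(Arrays.toString(transposedConv1d(x, w, 2))); // [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
    }
}
```

With stride 2 the kernel copies land two positions apart, so the output is roughly twice as long as the input. This is why transposed convolutions are commonly used for learned upsampling.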
The current implementations of Deconvolution are Conv1dTranspose, with the input dimension LayoutType.WIDTH, and Conv2dTranspose, with the input dimensions LayoutType.WIDTH and LayoutType.HEIGHT. These implementations share the same core principle as a Deconvolution layer. They differ only in the number of input dimensions each operates on, as denoted by ConvXdTranspose for X dimension(s).
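The paragraph above describes one shared base class whose subclasses differ only in dimensionality. The abstract methods numDimensions() and getStringLayout() express exactly that difference. The sketch below is a hypothetical plain-Java illustration of the split. The `*Sketch` class names are invented. The ranks 3 and 4 and the layout strings "NCW" and "NCHW" are assumptions: the WIDTH and HEIGHT axes named above, plus a batch axis and a channel axis.

```java
/**
 * Hypothetical sketch: shared logic lives in an abstract base class and each
 * subclass only states its dimensionality. Ranks and layouts are assumptions.
 */
abstract class DeconvSketch {
    /** Total input rank: batch axis + channel axis + X spatial axes. */
    protected abstract int numDimensions();

    /** One letter per input axis, e.g. "NCW" (assumed). */
    protected abstract String getStringLayout();

    /** Shared logic: the X in ConvXdTranspose, i.e. the number of spatial axes. */
    final int spatialDimensions() {
        return numDimensions() - 2;
    }

    /** Shared sanity check: the layout string has one letter per input axis. */
    final boolean isConsistent() {
        return getStringLayout().length() == numDimensions();
    }
}

final class Conv1dTransposeSketch extends DeconvSketch {
    @Override
    protected int numDimensions() { return 3; }          // N, C, W

    @Override
    protected String getStringLayout() { return "NCW"; }
}

final class Conv2dTransposeSketch extends DeconvSketch {
    @Override
    protected int numDimensions() { return 4; }          // N, C, H, W

    @Override
    protected String getStringLayout() { return "NCHW"; }
}

final class DeconvSketchDemo {
    public static void main(String[] args) {
        DeconvSketch[] blocks = {new Conv1dTransposeSketch(), new Conv2dTransposeSketch()};
        for (DeconvSketch b : blocks) {
            System.out.println(b.getStringLayout() + " -> " + b.spatialDimensions() + "-D");
        }
    }
}
```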
-
Nested Class Summary
Nested Classes
- static class Deconvolution.DeconvolutionBuilder<T extends Deconvolution.DeconvolutionBuilder>: A builder that can build any Deconvolution block.
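The builder's type parameter, T extends Deconvolution.DeconvolutionBuilder, appears to follow the self-bounded ("self type") builder pattern. In that pattern, setters declared on the shared base builder return T, so a subclass builder can keep chaining its own methods after calling a shared one. The sketch below is a hypothetical illustration of that pattern only. All class and method names are invented and are not DJL's API, and the default stride of all ones is an assumption.

```java
import java.util.Arrays;

/** Immutable result of the hypothetical builder below. */
final class DeconvConfigSketch {
    final int filters;
    final long[] kernel;
    final long[] stride;

    DeconvConfigSketch(int filters, long[] kernel, long[] stride) {
        this.filters = filters;
        this.kernel = kernel.clone();
        this.stride = stride.clone();
    }
}

/** Shared builder: setters return T so subclasses keep their own type when chaining. */
abstract class DeconvBuilderSketch<T extends DeconvBuilderSketch<T>> {
    protected int filters;
    protected long[] kernel;
    protected long[] stride; // null means "use the default" (assumed: all ones)

    public T filters(int filters) {
        this.filters = filters;
        return self();
    }

    public T kernel(long... kernel) {
        this.kernel = kernel.clone();
        return self();
    }

    public T stride(long... stride) {
        this.stride = stride.clone();
        return self();
    }

    /** Returns this builder typed as the concrete subclass. */
    protected abstract T self();
}

/** Concrete 2-D builder: build() exists only here, yet is reachable after base setters. */
final class Conv2dTransposeBuilderSketch extends DeconvBuilderSketch<Conv2dTransposeBuilderSketch> {
    @Override
    protected Conv2dTransposeBuilderSketch self() {
        return this;
    }

    DeconvConfigSketch build() {
        if (kernel == null || kernel.length != 2) {
            throw new IllegalStateException("a 2-D kernel shape is required");
        }
        if (stride != null && stride.length != 2) {
            throw new IllegalStateException("stride must have 2 values");
        }
        if (filters <= 0) {
            throw new IllegalStateException("filters must be positive");
        }
        long[] s = (stride == null) ? new long[] {1, 1} : stride;
        return new DeconvConfigSketch(filters, kernel, s);
    }

    public static void main(String[] args) {
        DeconvConfigSketch c =
                new Conv2dTransposeBuilderSketch().filters(8).kernel(3, 3).stride(2, 2).build();
        System.out.println(c.filters + " " + Arrays.toString(c.kernel) + " " + Arrays.toString(c.stride));
    }
}
```

Without the self type, filters(8) would return the base builder type, and the subsequent .build() call would not compile.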
-
Field Summary
Fields
- protected Parameter bias
- protected Shape dilation
- protected int filters
- protected int groups
- protected boolean includeBias
- protected Shape kernelShape
- protected Shape outPadding
- protected Shape padding
- protected Shape stride
- protected Parameter weight
-
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
-
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, version
-
Constructor Summary
Constructors
- Deconvolution(Deconvolution.DeconvolutionBuilder<?> builder): Creates a Deconvolution object.
-
Method Summary
Methods
- protected void beforeInitialize(Shape... inputShapes): Performs any action necessary before initialization.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- protected abstract LayoutType[] getExpectedLayout(): Returns the expected layout of the input.
- public Shape[] getOutputShapes(Shape[] inputs): Returns the expected output shapes of the block for the specified input shapes.
- protected abstract java.lang.String getStringLayout(): Returns the string representing the layout of the input.
- public void loadMetadata(byte loadVersion, java.io.DataInputStream is): Override this to load additional metadata with the parameter values.
- protected abstract int numDimensions(): Returns the number of dimensions of the input.
- protected void prepare(Shape[] inputs): Sets the shapes of the Parameters.
-
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
-
Methods inherited from class ai.djl.nn.AbstractBaseBlock
cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getParameters, initialize, initializeChildBlocks, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
-
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters
-
Field Detail
-
kernelShape
protected Shape kernelShape
-
stride
protected Shape stride
-
padding
protected Shape padding
-
outPadding
protected Shape outPadding
-
dilation
protected Shape dilation
-
filters
protected int filters
-
groups
protected int groups
-
includeBias
protected boolean includeBias
-
weight
protected Parameter weight
-
bias
protected Parameter bias
-
Constructor Detail
-
Deconvolution
public Deconvolution(Deconvolution.DeconvolutionBuilder<?> builder)
Creates a Deconvolution object.
Parameters:
builder - the Builder that has the necessary configurations
-
Method Detail
-
getExpectedLayout
protected abstract LayoutType[] getExpectedLayout()
Returns the expected layout of the input.
Returns: the expected layout of the input
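getExpectedLayout() and getStringLayout() describe the same input layout in two forms: an array of per-axis values and a compact string. The sketch below is a hypothetical illustration of how a string such as "NCHW" could map to one enum value per axis. `AxisSketch` is an invented stand-in, not DJL's LayoutType; its letters (N, C, D, H, W) are assumptions following the common naming convention.

```java
import java.util.Arrays;

/** Hypothetical stand-in for a layout enum; not DJL's LayoutType. */
enum AxisSketch {
    BATCH('N'), CHANNEL('C'), DEPTH('D'), HEIGHT('H'), WIDTH('W');

    private final char letter;

    AxisSketch(char letter) {
        this.letter = letter;
    }

    /** Maps one layout letter to its axis, rejecting unknown letters. */
    static AxisSketch fromLetter(char c) {
        for (AxisSketch a : values()) {
            if (a.letter == c) {
                return a;
            }
        }
        throw new IllegalArgumentException("Unknown layout letter: " + c);
    }

    /** Converts a layout string such as "NCHW" into one enum value per axis. */
    static AxisSketch[] fromLayout(String layout) {
        AxisSketch[] axes = new AxisSketch[layout.length()];
        for (int i = 0; i < layout.length(); i++) {
            axes[i] = fromLetter(layout.charAt(i));
        }
        return axes;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(fromLayout("NCHW"))); // [BATCH, CHANNEL, HEIGHT, WIDTH]
    }
}
```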
-
getStringLayout
protected abstract java.lang.String getStringLayout()
Returns the string representing the layout of the input.
Returns: the string representing the layout of the input
-
numDimensions
protected abstract int numDimensions()
Returns the number of dimensions of the input.
Returns: the number of dimensions of the input
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by: forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns: the output of the forward pass
-
beforeInitialize
protected void beforeInitialize(Shape... inputShapes)
Performs any action necessary before initialization, for example recording the input information or verifying the layout.
Overrides: beforeInitialize in class AbstractBaseBlock
Parameters:
inputShapes - the expected shapes of the input
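The description mentions verifying the layout and keeping input information before initialization. The sketch below is one hypothetical form of such a check. It is not DJL's beforeInitialize, and `LayoutCheckSketch` and its method are invented. It requires the input to have one axis per letter of the expected layout string, then returns the channel count so it can be kept for later steps.

```java
import java.util.Arrays;

/** Hypothetical pre-initialization check; not DJL's actual beforeInitialize. */
final class LayoutCheckSketch {
    private LayoutCheckSketch() {}

    /**
     * Verifies that the input has one axis per letter of the expected layout
     * (e.g. "NCHW") and returns the size of the channel axis ('C').
     */
    static long verifyAndGetChannels(long[] inputShape, String expectedLayout) {
        if (inputShape.length != expectedLayout.length()) {
            throw new IllegalArgumentException("expected layout " + expectedLayout
                    + " but got input of shape " + Arrays.toString(inputShape));
        }
        int c = expectedLayout.indexOf('C');
        if (c < 0) {
            throw new IllegalArgumentException("layout has no channel axis: " + expectedLayout);
        }
        return inputShape[c];
    }

    public static void main(String[] args) {
        System.out.println(verifyAndGetChannels(new long[] {1, 16, 8, 8}, "NCHW")); // 16
    }
}
```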
-
prepare
protected void prepare(Shape[] inputs)
Sets the shapes of the Parameters.
Overrides: prepare in class AbstractBaseBlock
Parameters:
inputs - the shapes of the inputs
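The sketch below shows the kind of shapes a prepare step might assign to the weight and bias Parameters, computed from the input shape and the filters, groups, and kernelShape fields. It is a hypothetical illustration with invented class and method names. It assumes the (in_channels, filters / groups, k_1, ..., k_X) weight layout that PyTorch's ConvTranspose layers use. This page does not confirm that DJL uses the same convention.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of parameter shapes for a transposed convolution,
 * assuming a PyTorch-style (in_channels, filters / groups, k...) weight layout.
 */
final class DeconvParamShapesSketch {
    private DeconvParamShapesSketch() {}

    static long[] weightShape(long[] inputShape, int filters, int groups, long[] kernelShape) {
        long inChannels = inputShape[1]; // channel axis in an N, C, ... layout
        if (groups < 1 || filters % groups != 0 || inChannels % groups != 0) {
            throw new IllegalArgumentException(
                    "filters and input channels must be divisible by groups");
        }
        long[] shape = new long[2 + kernelShape.length];
        shape[0] = inChannels;
        shape[1] = filters / groups;
        System.arraycopy(kernelShape, 0, shape, 2, kernelShape.length);
        return shape;
    }

    /** One bias value per output channel (relevant only when a bias is included). */
    static long[] biasShape(int filters) {
        return new long[] {filters};
    }

    public static void main(String[] args) {
        long[] w = weightShape(new long[] {1, 16, 8, 8}, 8, 2, new long[] {3, 3});
        System.out.println(Arrays.toString(w));            // [16, 4, 3, 3]
        System.out.println(Arrays.toString(biasShape(8))); // [8]
    }
}
```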
-
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputs)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
inputs - the shapes of the inputs
Returns: the expected output shapes of the block
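For each spatial axis, the usual size formula for a transposed convolution is out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + outPadding + 1. This is the formula documented, for example, for PyTorch's ConvTranspose layers. The sketch below applies it to an N, C, spatial... input shape, using the kernelShape, stride, padding, outPadding, dilation, and filters fields listed above. It is a hypothetical helper. That it matches DJL's getOutputShapes exactly is an assumption, not something this page states.

```java
import java.util.Arrays;

/** Hypothetical helper applying the standard transposed-convolution size formula. */
final class DeconvOutputShapeSketch {
    private DeconvOutputShapeSketch() {}

    static long[] outputShape(long[] input, int filters, long[] kernel, long[] stride,
                              long[] padding, long[] outPadding, long[] dilation) {
        int spatial = input.length - 2; // axes after N and C
        for (long[] p : new long[][] {kernel, stride, padding, outPadding, dilation}) {
            if (p.length != spatial) {
                throw new IllegalArgumentException(
                        "expected " + spatial + " values per hyperparameter");
            }
        }
        long[] out = new long[input.length];
        out[0] = input[0]; // batch size is unchanged
        out[1] = filters;  // output channels = number of filters
        for (int i = 0; i < spatial; i++) {
            out[i + 2] = (input[i + 2] - 1) * stride[i] - 2 * padding[i]
                    + dilation[i] * (kernel[i] - 1) + outPadding[i] + 1;
        }
        return out;
    }

    public static void main(String[] args) {
        long[] one = {1, 1};
        long[] zero = {0, 0};
        long[] out = outputShape(new long[] {1, 16, 8, 8}, 8,
                new long[] {4, 4}, new long[] {2, 2}, one, zero, one);
        System.out.println(Arrays.toString(out)); // [1, 8, 16, 16]
    }
}
```

Kernel 4, stride 2, padding 1 is a common configuration for exact 2x upsampling. outPadding resolves the ambiguity that arises because several input sizes map to the same output size under a strided convolution.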
-
loadMetadata
public void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values.
If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches; if it does not, it throws a MalformedModelException. It then restores the input shapes.
Overrides: loadMetadata in class AbstractBaseBlock
Parameters:
loadVersion - the version used for loading this metadata
is - the input stream to load from
Throws:
java.io.IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format
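The default behavior described above has two steps: reject an unexpected version, then restore the input shapes. The sketch below shows that behavior with standard java.io streams. It is hypothetical throughout. The byte layout (a shape count, then a rank and the dimensions for each shape), the VERSION value, and the `MalformedDataException` stand-in for MalformedModelException are invented for illustration; DJL's real format is not described on this page.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

/** Hypothetical sketch of version-checked metadata loading; invented format. */
final class MetadataSketch {
    static final byte VERSION = 2; // hypothetical current format version

    /** Stand-in for DJL's MalformedModelException. */
    static final class MalformedDataException extends Exception {
        MalformedDataException(String message) {
            super(message);
        }
    }

    private MetadataSketch() {}

    static long[][] loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedDataException {
        // Step 1: reject versions this block does not know how to read.
        if (loadVersion != VERSION) {
            throw new MalformedDataException("Unsupported version: " + loadVersion);
        }
        // Step 2: restore the saved input shapes.
        int count = is.readInt();
        if (count < 0) {
            throw new MalformedDataException("negative shape count");
        }
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            int rank = is.readInt();
            if (rank < 0) {
                throw new MalformedDataException("negative rank");
            }
            shapes[i] = new long[rank];
            for (int d = 0; d < rank; d++) {
                shapes[i][d] = is.readLong();
            }
        }
        return shapes;
    }

    /** Writes shapes in the same invented format (the saving counterpart). */
    static byte[] writeShapes(long[][] shapes) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeInt(shapes.length);
            for (long[] shape : shapes) {
                os.writeInt(shape.length);
                for (long dim : shape) {
                    os.writeLong(dim);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        byte[] data = writeShapes(new long[][] {{1, 16, 8, 8}});
        long[][] shapes = loadMetadata(VERSION, new DataInputStream(new ByteArrayInputStream(data)));
        System.out.println(Arrays.toString(shapes[0])); // [1, 16, 8, 8]
    }
}
```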