Class Deconvolution

All Implemented Interfaces:
Block
Direct Known Subclasses:
Conv1dTranspose, Conv2dTranspose

public abstract class Deconvolution extends AbstractBlock
Transposed convolution, also called fractionally-strided convolution (Dumoulin & Visin) or deconvolution (Long et al., 2015), is a transformation that goes in the opposite direction of a normal convolution.

The need for it generally arises when going from something that has the shape of the output of some convolution to something that has the shape of its input, while maintaining a connectivity pattern that is compatible with said convolution.

Current implementations of Deconvolution are Conv1dTranspose, with an input dimension of LayoutType.WIDTH, and Conv2dTranspose, with input dimensions of LayoutType.WIDTH and LayoutType.HEIGHT. These implementations share the same core principle as a Deconvolution layer; they differ only in the number of input dimensions each operates on, as denoted by ConvXdTranspose for X dimension(s).
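
To make the "opposite direction" idea above concrete, here is a minimal, dependency-free sketch of a 1D convolution and its transposed counterpart. It does not use DJL; the class and method names are hypothetical, and it assumes a single channel with no padding, output padding, or dilation.

```java
import java.util.Arrays;

/** Hypothetical illustration only; not part of DJL. */
final class TransposedConv1dSketch {

    /** Plain 1D convolution (cross-correlation) without padding or dilation. */
    static float[] conv1d(float[] input, float[] kernel, int stride) {
        int outLen = (input.length - kernel.length) / stride + 1;
        float[] out = new float[outLen];
        for (int o = 0; o < outLen; o++) {
            for (int k = 0; k < kernel.length; k++) {
                // Each output element gathers a window of the input.
                out[o] += input[o * stride + k] * kernel[k];
            }
        }
        return out;
    }

    /** Transposed 1D convolution: each input element scatters a scaled kernel copy. */
    static float[] convTranspose1d(float[] input, float[] kernel, int stride) {
        int outLen = (input.length - 1) * stride + kernel.length;
        float[] out = new float[outLen];
        for (int i = 0; i < input.length; i++) {
            for (int k = 0; k < kernel.length; k++) {
                // Same index pairing as conv1d, with the roles of input and output swapped.
                out[i * stride + k] += input[i] * kernel[k];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        float[] kernel = {1f, 2f, 3f};
        float[] x = {1f, 0f, 2f, 1f, 0f, 1f, 3f};
        float[] y = conv1d(x, kernel, 2);          // shape of a convolution output
        float[] z = convTranspose1d(y, kernel, 2); // back to the shape of x
        System.out.println(x.length + " -> " + y.length + " -> " + z.length);
        System.out.println(Arrays.toString(z));
    }
}
```

The transposed version restores the shape of the original input, not its values; it reuses the index pairing `i * stride + k` of the forward convolution, which is what keeps the connectivity pattern compatible.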

  • Field Details

    • kernelShape

      protected Shape kernelShape
    • stride

      protected Shape stride
    • padding

      protected Shape padding
    • outPadding

      protected Shape outPadding
    • dilation

      protected Shape dilation
    • filters

      protected int filters
    • groups

      protected int groups
    • includeBias

      protected boolean includeBias
    • weight

      protected Parameter weight
    • bias

      protected Parameter bias
  • Constructor Details

  • Method Details

    • getExpectedLayout

      protected abstract LayoutType[] getExpectedLayout()
      Returns the expected layout of the input.
      Returns:
      the expected layout of the input
    • getStringLayout

      protected abstract String getStringLayout()
      Returns the string representing the layout of the input.
      Returns:
      the string representing the layout of the input
    • numDimensions

      protected abstract int numDimensions()
      Returns the number of dimensions of the input.
      Returns:
      the number of dimensions of the input
    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • beforeInitialize

      protected void beforeInitialize(Shape... inputShapes)
      Performs any action necessary before initialization. For example, keep the input information or verify the layout.
      Overrides:
      beforeInitialize in class AbstractBaseBlock
      Parameters:
      inputShapes - the expected shapes of the input
    • prepare

      protected void prepare(Shape[] inputs)
      Sets the shape of Parameters.
      Overrides:
      prepare in class AbstractBaseBlock
      Parameters:
      inputs - the shapes of inputs
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputs)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputs - the shapes of the inputs
      Returns:
      the expected output shapes of the block
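
      As a rough illustration of how an output shape can be derived from the fields listed above, the sketch below applies the commonly used transposed-convolution size formula to each spatial axis. It is an assumption that this class follows the same formula and a (batch, channel, spatial...) layout; the class and method names are hypothetical and plain long arrays stand in for Shape.

```java
/** Hypothetical illustration only; not part of DJL. */
final class DeconvOutputShapeSketch {

    /** Output size along one spatial axis, using the commonly used formula. */
    static long outputSize(
            long in, long kernel, long stride, long padding, long outPadding, long dilation) {
        return (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + outPadding + 1;
    }

    /** Assumes inputShape is laid out as (batch, channel, spatial...). */
    static long[] outputShape(
            long[] inputShape,
            long filters,
            long[] kernel,
            long[] stride,
            long[] padding,
            long[] outPadding,
            long[] dilation) {
        long[] out = new long[inputShape.length];
        out[0] = inputShape[0]; // batch size is unchanged
        out[1] = filters;       // channel count becomes the number of filters
        for (int i = 0; i < kernel.length; i++) {
            out[2 + i] =
                    outputSize(
                            inputShape[2 + i],
                            kernel[i],
                            stride[i],
                            padding[i],
                            outPadding[i],
                            dilation[i]);
        }
        return out;
    }
}
```

      For example, under these assumptions a (1, 3, 16, 16) input with 8 filters, a 4x4 kernel, stride 2, padding 1, no output padding, and dilation 1 gives (1, 8, 32, 32).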
    • loadMetadata

      public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
      Override this to load additional metadata with the parameter values.

      If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. After that, it restores the input shapes.

      Overrides:
      loadMetadata in class AbstractBaseBlock
      Parameters:
      loadVersion - the version used for loading this metadata.
      is - the input stream we are loading from
      Throws:
      IOException - loading failed
      MalformedModelException - data can be loaded but has wrong format
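
      The version check followed by restoring the input shapes can be sketched with the standard library alone. Everything below is hypothetical: the version constant, the on-disk layout (rank followed by dimensions), and the MalformedDataException stand-in are assumptions made for illustration, not the format this class reads.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical illustration only; none of these names are DJL API. */
final class MetadataVersionSketch {

    /** Hypothetical format version understood by loadMetadata below. */
    static final byte VERSION = 2;

    /** Hypothetical stand-in for MalformedModelException. */
    static final class MalformedDataException extends Exception {
        MalformedDataException(String message) {
            super(message);
        }
    }

    /** Writes the rank and the dimensions of one input shape. */
    static void saveMetadata(DataOutputStream os, long[] inputShape) throws IOException {
        os.writeInt(inputShape.length);
        for (long dim : inputShape) {
            os.writeLong(dim);
        }
    }

    /** Rejects unknown versions first, then restores the input shape. */
    static long[] loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, MalformedDataException {
        if (loadVersion != VERSION) {
            throw new MalformedDataException("Unsupported encoding version: " + loadVersion);
        }
        long[] inputShape = new long[is.readInt()];
        for (int i = 0; i < inputShape.length; i++) {
            inputShape[i] = is.readLong();
        }
        return inputShape;
    }

    /** Saves a shape in memory and loads it back; returns null on a version mismatch. */
    static long[] roundTrip(long[] inputShape, byte loadVersion) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            saveMetadata(new DataOutputStream(buffer), inputShape);
            DataInputStream is =
                    new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()));
            return loadMetadata(loadVersion, is);
        } catch (MalformedDataException e) {
            return null;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

      An override that adds its own fields would keep the save and load sides symmetric: whatever extra values the save method writes, the load method reads back in the same order after the version check.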