Class Convolution.ConvolutionBuilder<T extends Convolution.ConvolutionBuilder>

java.lang.Object
ai.djl.nn.convolutional.Convolution.ConvolutionBuilder<T>
Type Parameters:
T - the type of Convolution block to build
Direct Known Subclasses:
Conv1d.Builder, Conv2d.Builder, Conv3d.Builder
Enclosing class:
Convolution

public abstract static class Convolution.ConvolutionBuilder<T extends Convolution.ConvolutionBuilder> extends Object
A builder that can build any Convolution block.
  • Field Details

    • kernelShape

      protected Shape kernelShape
    • stride

      protected Shape stride
    • padding

      protected Shape padding
    • dilation

      protected Shape dilation
    • filters

      protected int filters
    • groups

      protected int groups
    • includeBias

      protected boolean includeBias
  • Constructor Details

    • ConvolutionBuilder

      public ConvolutionBuilder()
  • Method Details

    • setKernelShape

      public T setKernelShape(Shape kernelShape)
      Sets the shape of the kernel.
      Parameters:
      kernelShape - the shape of the kernel
      Returns:
      this Builder
    • optStride

      public T optStride(Shape stride)
      Sets the stride of the convolution. Defaults to 1 in each dimension.
      Parameters:
      stride - the shape of the stride
      Returns:
      this Builder
    • optPadding

      public T optPadding(Shape padding)
      Sets the padding along each dimension. Defaults to 0 along each dimension.
      Parameters:
      padding - the shape of padding along each dimension
      Returns:
      this Builder
    • optDilation

      public T optDilation(Shape dilate)
      Sets the dilation along each dimension. Defaults to 1 along each dimension.
      Parameters:
      dilate - the shape of dilation along each dimension
      Returns:
      this Builder
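
To see how stride, padding, and dilation interact, the sketch below computes the output length of one spatial dimension with the conventional convolution arithmetic. It assumes the underlying engine follows the usual floor-based formula. ConvOutputSize and outputSize are hypothetical names for illustration and are not part of DJL.

```java
/** Hypothetical helper (not a DJL class): output length of one spatial dimension. */
final class ConvOutputSize {
    private ConvOutputSize() {}

    /**
     * Conventional formula, assumed here rather than taken from this page:
     * out = floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1
     */
    static long outputSize(long input, long kernel, long stride, long padding, long dilation) {
        if (stride <= 0 || kernel <= 0 || dilation <= 0 || padding < 0) {
            throw new IllegalArgumentException("stride, kernel, dilation must be > 0 and padding >= 0");
        }
        // A dilated kernel spans dilation * (kernel - 1) + 1 input positions.
        long effectiveKernel = dilation * (kernel - 1) + 1;
        return Math.floorDiv(input + 2 * padding - effectiveKernel, stride) + 1;
    }

    public static void main(String[] args) {
        System.out.println(outputSize(32, 3, 1, 1, 1)); // 3-wide kernel, padding 1: prints 32
        System.out.println(outputSize(32, 3, 2, 1, 1)); // same with stride 2: prints 16
        System.out.println(outputSize(32, 3, 1, 0, 2)); // dilation 2, no padding: prints 28
    }
}
```

With the defaults listed above (stride 1, padding 0, dilation 1) the expression reduces to in - kernel + 1 for each dimension.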
    • setFilters

      public T setFilters(int filters)
      Sets the required number of filters.
      Parameters:
      filters - the number of convolution filters (output channels)
      Returns:
      this Builder
    • optGroups

      public T optGroups(int groups)
      Sets the number of group partitions.
      Parameters:
      groups - the number of group partitions
      Returns:
      this Builder
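
As a rough illustration of what group partitions change, the sketch below counts weight parameters for a grouped convolution. It assumes the usual grouped-convolution layout, where each filter sees only inChannels / groups input channels and both channel counts must be divisible by groups. GroupedConvParams and weightCount are hypothetical names, not DJL API.

```java
/** Hypothetical helper (not a DJL class): weight count of a grouped convolution. */
final class GroupedConvParams {
    private GroupedConvParams() {}

    /** Returns filters * (inChannels / groups) * product(kernel), excluding any bias. */
    static long weightCount(long inChannels, long filters, long groups, long... kernel) {
        if (groups <= 0 || inChannels % groups != 0 || filters % groups != 0) {
            throw new IllegalArgumentException(
                    "groups must be positive and divide both inChannels and filters");
        }
        // Each filter only connects to the input channels of its own group.
        long perFilter = inChannels / groups;
        for (long k : kernel) {
            perFilter *= k;
        }
        return filters * perFilter;
    }

    public static void main(String[] args) {
        System.out.println(weightCount(64, 128, 1, 3, 3));  // ungrouped: prints 73728
        System.out.println(weightCount(64, 128, 4, 3, 3));  // 4 groups: prints 18432
        System.out.println(weightCount(64, 64, 64, 3, 3));  // depthwise: prints 576
    }
}
```

Under this assumption, raising groups from 1 to g divides the weight count by g while leaving the number of output channels unchanged.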
    • optBias

      public T optBias(boolean includeBias)
      Sets whether to include a bias vector. The bias is included by default.
      Parameters:
      includeBias - whether to use a bias vector parameter
      Returns:
      this Builder
    • validate

      protected void validate()
      Validates that the required arguments are set.
      Throws:
      IllegalArgumentException - if the required arguments are not set
    • self

      protected abstract T self()
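
The setters above return T rather than the base builder type, and self() is what lets a subclass supply that value. The sketch below reproduces the pattern with stand-in classes so it compiles without DJL. SketchConvBuilder and SketchConv2dBuilder are hypothetical names, the kernel is a plain long[] instead of Shape, the type bound is written recursively as a choice of this sketch, and the validate() check (null kernel or zero filters counts as unset) is an assumption about what "required arguments" means.

```java
/** Stand-in for the abstract builder; hypothetical, not a DJL class. */
abstract class SketchConvBuilder<T extends SketchConvBuilder<T>> {
    protected long[] kernelShape;          // required, no default
    protected int filters;                 // required, 0 is treated as "not set"
    protected boolean includeBias = true;  // optional, bias included by default

    public T setKernelShape(long... kernelShape) {
        this.kernelShape = kernelShape.clone();
        return self();
    }

    public T setFilters(int filters) {
        this.filters = filters;
        return self();
    }

    public T optBias(boolean includeBias) {
        this.includeBias = includeBias;
        return self();
    }

    /** Assumed check: both required arguments must have been supplied. */
    protected void validate() {
        if (kernelShape == null || filters == 0) {
            throw new IllegalArgumentException("kernelShape and filters must be set");
        }
    }

    /** Each concrete builder returns itself so chained calls keep the subclass type. */
    protected abstract T self();
}

/** Stand-in for a concrete builder such as a 2D variant; hypothetical. */
final class SketchConv2dBuilder extends SketchConvBuilder<SketchConv2dBuilder> {
    @Override
    protected SketchConv2dBuilder self() {
        return this;
    }

    /** Subclass-only method, reachable at the end of a chain because setters return T. */
    public String build() {
        validate();
        return "Conv2d(kernel=" + java.util.Arrays.toString(kernelShape)
                + ", filters=" + filters + ", bias=" + includeBias + ")";
    }
}
```

In this sketch, new SketchConv2dBuilder().setKernelShape(3, 3).setFilters(32).optBias(false).build() compiles because every setter returns SketchConv2dBuilder. Calling build() without a kernel shape throws IllegalArgumentException from validate().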