Package ai.djl.nn

Class AbstractBaseBlock

java.lang.Object
ai.djl.nn.AbstractBaseBlock
All Implemented Interfaces:
Block
Direct Known Subclasses:
AbstractBlock, AbstractSymbolBlock

public abstract class AbstractBaseBlock extends Object implements Block
Provides functionality shared by the DJL-native AbstractBlock subclasses and the imported AbstractSymbolBlock subclasses.
  • Field Details

    • version

      protected byte version
      The model version of this block, used for checking if parameters are still valid during parameter loading.
    • inputShapes

      protected Shape[] inputShapes
      The shapes of the inputs to this block, set during initialization.
    • outputDataTypes

      protected DataType[] outputDataTypes
    • inputNames

      protected List<String> inputNames
      The names of the inputs. Subclasses that use named inputs must set them manually.
  • Constructor Details

    • AbstractBaseBlock

      public AbstractBaseBlock()
      Constructs a new AbstractBaseBlock instance.
    • AbstractBaseBlock

      public AbstractBaseBlock(byte version)
      Builds an empty block with the given version for parameter serialization.
      Parameters:
      version - the version to use for parameter serialization.
  • Method Details

    • forward

      public final NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Applies the operating function of the block once. This method should be called only on blocks that are initialized.
      Specified by:
      forward in interface Block
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass (enables training-only behavior such as dropout, and batch statistics in batch normalization)
      params - optional parameters
      Returns:
      the output of the forward pass
    • forward

      public NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String,Object> params)
      A forward call using both training data and labels.

      Within this forward call, it can be assumed that training is true.

      Specified by:
      forward in interface Block
      Parameters:
      parameterStore - the parameter store
      data - the input data NDList
      labels - the input labels NDList
      params - optional parameters
      Returns:
      the output of the forward pass
      See Also:
    • forwardInternal

      protected abstract NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
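The split between the final forward and the abstract forwardInternal is a template-method pattern: the base class owns the shared checks and each subclass supplies only the computation. The stand-alone sketch below models that split without DJL. SketchBaseBlock and SketchDoubler are hypothetical names, double[] stands in for NDList, and the ParameterStore and params arguments are omitted. It assumes, consistent with the statement that training can be taken as true in the labels overload, that the default labels hook ignores the labels and runs a training pass. Throwing IllegalStateException for an uninitialized block is a choice of this sketch; the real class may handle that case differently.

```java
/**
 * Hypothetical stand-in for AbstractBaseBlock: double[] replaces NDList and the
 * ParameterStore/params arguments are omitted.
 */
abstract class SketchBaseBlock {
    private int[] inputShape; // null until initialize() has run

    public void initialize(int... inputShape) {
        this.inputShape = inputShape.clone();
    }

    public boolean isInitialized() {
        return inputShape != null;
    }

    /** Final entry point: performs the shared check, then delegates to the subclass hook. */
    public final double[] forward(double[] inputs, boolean training) {
        requireInitialized();
        return forwardInternal(inputs, training);
    }

    /** Labels overload: same check, then the labels hook. */
    public double[] forward(double[] data, double[] labels) {
        requireInitialized();
        return forwardInternal(data, labels);
    }

    /** The computation every concrete block must supply. */
    protected abstract double[] forwardInternal(double[] inputs, boolean training);

    /** Assumed default: ignore the labels and run a training pass. */
    protected double[] forwardInternal(double[] data, double[] labels) {
        return forwardInternal(data, true);
    }

    // Sketch choice: fail fast when forward is called before initialize.
    private void requireInitialized() {
        if (!isInitialized()) {
            throw new IllegalStateException("initialize() must be called before forward()");
        }
    }
}

/** Hypothetical concrete block that doubles every input and remembers the training flag. */
class SketchDoubler extends SketchBaseBlock {
    boolean lastTraining;

    @Override
    protected double[] forwardInternal(double[] inputs, boolean training) {
        lastTraining = training;
        double[] out = new double[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            out[i] = inputs[i] * 2;
        }
        return out;
    }
}
```

Because forward is final in this sketch, no subclass can bypass the initialization check; subclasses override only forwardInternal.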
    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String,Object> params)
      Parameters:
      parameterStore - the parameter store
      data - the input data NDList
      labels - the input labels NDList
      params - optional parameters
      Returns:
      the output of the forward pass
      See Also:
    • describeInput

      public ai.djl.util.PairList<String,Shape> describeInput()
      Returns a PairList of input names and their shapes.
      Specified by:
      describeInput in interface Block
      Returns:
      the PairList of input names and their shapes
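Conceptually, describeInput pairs each input name with the shape at the same position. The sketch below models that with java.util only. SketchInputDescriber is a hypothetical name, long[] stands in for Shape, and a list of Map.Entry stands in for PairList. Rejecting mismatched name and shape counts is a choice made for this sketch, not documented behavior.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical model of describeInput(): pairs each input name with its shape. */
class SketchInputDescriber {
    static List<Map.Entry<String, long[]>> describeInput(List<String> inputNames, long[][] inputShapes) {
        if (inputNames.size() != inputShapes.length) {
            // Sketch choice: every shape needs a name (subclasses set the names manually).
            throw new IllegalArgumentException("expected one name per input shape");
        }
        List<Map.Entry<String, long[]>> pairs = new ArrayList<>();
        for (int i = 0; i < inputShapes.length; i++) {
            // Position i of the names lines up with position i of the shapes.
            pairs.add(new AbstractMap.SimpleImmutableEntry<>(inputNames.get(i), inputShapes[i]));
        }
        return pairs;
    }
}
```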
    • setInitializer

      public void setInitializer(Initializer initializer, Parameter.Type params)
      Sets an Initializer on all parameters of the block that match the given parameter type.
      Specified by:
      setInitializer in interface Block
      Parameters:
      initializer - the initializer to set
      params - the parameter type whose parameters should receive the initializer
    • setInitializer

      public void setInitializer(Initializer initializer, String paramName)
      Sets an Initializer on the specified direct parameter of the block, overriding any initializer already set on that parameter.
      Specified by:
      setInitializer in interface Block
      Parameters:
      initializer - the initializer to be set
      paramName - the name of the parameter
    • setInitializer

      public void setInitializer(Initializer initializer, Predicate<Parameter> predicate)
      Sets an Initializer on all parameters of the block that match the given Predicate.
      Specified by:
      setInitializer in interface Block
      Parameters:
      initializer - the initializer to be set
      predicate - a predicate that selects the parameters to set
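The three setInitializer overloads differ only in how they select parameters, so the type-based and name-based forms can be modeled on top of a predicate. The stdlib-only sketch below does that. SketchParam, SketchInitBlock, the string initializer names, and the two-value Type enum are hypothetical stand-ins for Parameter, a block, Initializer, and Parameter.Type. The sketch has no child blocks, and throwing IllegalArgumentException for an unknown name is an assumption.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

/** Hypothetical parameter: a name, a type, and the name of its assigned initializer. */
class SketchParam {
    enum Type { WEIGHT, BIAS }

    final String name;
    final Type type;
    String initializer; // null until an initializer is assigned

    SketchParam(String name, Type type) {
        this.name = name;
        this.type = type;
    }
}

/** Hypothetical block with the three setInitializer overloads; it has no children. */
class SketchInitBlock {
    final List<SketchParam> params = new ArrayList<>();

    /** General form: assign to every parameter the predicate accepts. */
    void setInitializer(String initializer, Predicate<SketchParam> predicate) {
        for (SketchParam p : params) {
            if (predicate.test(p)) {
                p.initializer = initializer;
            }
        }
    }

    /** By type: expressed through the predicate form. */
    void setInitializer(String initializer, SketchParam.Type type) {
        setInitializer(initializer, p -> p.type == type);
    }

    /** By name: overrides any initializer the named parameter already has. */
    void setInitializer(String initializer, String paramName) {
        for (SketchParam p : params) {
            if (p.name.equals(paramName)) {
                p.initializer = initializer;
                return;
            }
        }
        throw new IllegalArgumentException("no direct parameter named " + paramName);
    }
}
```

Later calls win in this sketch, so a name-based call after a type-based one overrides the earlier choice for that single parameter.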
    • initialize

      public void initialize(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the parameters of the block, enables gradient tracking where required, and infers the block's input shapes. This method must be called before calling `forward`.
      Specified by:
      initialize in interface Block
      Parameters:
      manager - the NDManager used to initialize the parameters
      dataType - the datatype of the parameters
      inputShapes - the shapes of the inputs to the block
    • beforeInitialize

      protected void beforeInitialize(Shape... inputShapes)
      Performs any action necessary before initialization, for example recording the input information or verifying the layout.
      Parameters:
      inputShapes - the expected shapes of the input
    • initializeChildBlocks

      protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. Override this method if your subclass has child blocks; it determines the correct input shapes for the child blocks from the requested input shapes of this block.
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
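For a block whose children run one after another, determining each child's input shape amounts to threading the shape through the chain: the input shape of one child is the output shape of the child before it. The stand-alone sketch below illustrates that idea. SketchChild, SketchDense and SketchSequential are hypothetical, long[] stands in for Shape, and a dense child is assumed to map (batch, in) to (batch, units).

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical child block: initializes from an input shape and reports its output shape. */
interface SketchChild {
    void initialize(long[] inputShape);

    long[] outputShape(long[] inputShape);
}

/** Hypothetical dense child mapping (batch, in) to (batch, units). */
class SketchDense implements SketchChild {
    final long units;
    long[] initializedWith; // input shape received during initialization

    SketchDense(long units) {
        this.units = units;
    }

    @Override
    public void initialize(long[] inputShape) {
        initializedWith = inputShape.clone();
    }

    @Override
    public long[] outputShape(long[] inputShape) {
        return new long[] {inputShape[0], units};
    }
}

/** Hypothetical sequential parent whose children run one after another. */
class SketchSequential {
    final List<SketchChild> children = new ArrayList<>();

    /** Each child's output shape becomes the next child's input shape. */
    long[] initializeChildBlocks(long[] inputShape) {
        long[] shape = inputShape;
        for (SketchChild child : children) {
            child.initialize(shape);
            shape = child.outputShape(shape);
        }
        return shape; // output shape of the whole chain
    }
}
```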
    • prepare

      protected void prepare(Shape[] inputShapes)
      Sets the shapes of the block's parameters from the given input shapes.
      Parameters:
      inputShapes - the shapes of inputs
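The descriptions of initialize, beforeInitialize, prepare and initializeChildBlocks suggest a call order, which the stdlib-only sketch below records. SketchLifecycleBlock is hypothetical. Both the exact order (beforeInitialize, prepare, parameter allocation, initializeChildBlocks) and the rule that prepare is skipped for an already initialized block are assumptions made for illustration, not documented guarantees. Its isInitialized follows the definition given for isInitialized on this page: input shapes are present and parameter arrays are non-null.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical block that records the order in which its initialization hooks run. */
class SketchLifecycleBlock {
    final List<String> calls = new ArrayList<>();
    long[] inputShape;  // recorded by beforeInitialize
    long[] weightShape; // computed by prepare
    double[] weight;    // allocated once its shape is known

    void initialize(long... inputShape) {
        beforeInitialize(inputShape);
        if (!isInitialized()) { // assumption: shapes are prepared only once
            prepare(inputShape);
            weight = new double[Math.toIntExact(weightShape[0] * weightShape[1])];
            calls.add("initialize parameters");
        }
        initializeChildBlocks(inputShape);
    }

    /** Input shapes present and parameter arrays non-null. */
    boolean isInitialized() {
        return inputShape != null && weight != null;
    }

    void beforeInitialize(long... inputShape) {
        this.inputShape = inputShape.clone(); // keep the input information
        calls.add("beforeInitialize");
    }

    void prepare(long[] inputShape) {
        // Hypothetical dense layer with 4 output units: weight is (features, 4).
        weightShape = new long[] {inputShape[1], 4};
        calls.add("prepare");
    }

    void initializeChildBlocks(long... inputShape) {
        calls.add("initializeChildBlocks"); // no children in this sketch
    }
}
```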
    • getParameters

      public ParameterList getParameters()
      Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
      Specified by:
      getParameters in interface Block
      Returns:
      the list of all parameters of the block
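Collecting parameters recursively can be sketched as a depth-first walk over a tree of blocks: a block's own parameters first, then those of each child. SketchTreeBlock is hypothetical. The ordering and the child-name prefix (childName_paramName) are assumptions of the sketch, and only parameter names are returned to keep the example short.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical block tree: each node owns direct parameters and named child blocks. */
class SketchTreeBlock {
    final Map<String, double[]> directParameters = new LinkedHashMap<>();
    final Map<String, SketchTreeBlock> children = new LinkedHashMap<>();

    /** Depth-first: own parameters first, then each child's, prefixed with the child name. */
    List<String> getParameterNames() {
        List<String> names = new ArrayList<>(directParameters.keySet());
        for (Map.Entry<String, SketchTreeBlock> child : children.entrySet()) {
            for (String name : child.getValue().getParameterNames()) {
                names.add(child.getKey() + "_" + name); // hypothetical naming scheme
            }
        }
        return names;
    }
}
```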
    • isInitialized

      public boolean isInitialized()
      Returns whether the block is initialized, that is, whether it has input shapes and all of its parameters have non-null arrays.
      Specified by:
      isInitialized in interface Block
      Returns:
      whether the block is initialized
    • clear

      public void clear()
      Closes all the parameters of the block. All the updates made during training will be lost.
      Specified by:
      clear in interface Block
    • cast

      public void cast(DataType dataType)
      Guaranteed to throw an exception; casting is not yet implemented.
      Specified by:
      cast in interface Block
      Parameters:
      dataType - the data type to cast to
    • saveParameters

      public void saveParameters(DataOutputStream os) throws IOException
      Writes the parameters of the block to the given output stream.
      Specified by:
      saveParameters in interface Block
      Parameters:
      os - the output stream to save the parameters to
      Throws:
      IOException - if an I/O error occurs
    • loadParameters

      public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException
      Loads the parameters from the given input stream.
      Specified by:
      loadParameters in interface Block
      Parameters:
      manager - an NDManager to create the parameter arrays
      is - the input stream that provides the parameter values
      Throws:
      IOException - if an I/O error occurs
      MalformedModelException - if the model file is corrupted or unsupported
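The stdlib-only sketch below shows the round trip that saveParameters and loadParameters enable: parameters written from one block instance are read back into a fresh one. SketchSerializableBlock is hypothetical, and its byte layout (count, then name, length and values for each parameter) is invented for the example; it is not DJL's parameter file format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical block whose parameters round-trip through Data*Stream (not DJL's format). */
class SketchSerializableBlock {
    final Map<String, double[]> parameters = new LinkedHashMap<>();

    void saveParameters(DataOutputStream os) throws IOException {
        os.writeInt(parameters.size());
        for (Map.Entry<String, double[]> e : parameters.entrySet()) {
            os.writeUTF(e.getKey());
            os.writeInt(e.getValue().length);
            for (double v : e.getValue()) {
                os.writeDouble(v);
            }
        }
    }

    void loadParameters(DataInputStream is) throws IOException {
        int count = is.readInt();
        for (int i = 0; i < count; i++) {
            String name = is.readUTF();
            double[] values = new double[is.readInt()];
            for (int j = 0; j < values.length; j++) {
                values[j] = is.readDouble();
            }
            parameters.put(name, values);
        }
    }

    /** Saves a trained block in memory and loads the bytes into a fresh instance. */
    static SketchSerializableBlock roundTrip(SketchSerializableBlock trained) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(bytes)) {
                trained.saveParameters(os);
            }
            SketchSerializableBlock fresh = new SketchSerializableBlock();
            try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                fresh.loadParameters(is);
            }
            return fresh;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```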
    • saveMetadata

      protected void saveMetadata(DataOutputStream os) throws IOException
      Override this method to save additional data apart from parameter values.

      This default implementation saves the currently set input shapes.

      Parameters:
      os - the non-null output stream the parameter values and metadata are written to
      Throws:
      IOException - saving failed
    • loadMetadata

      protected void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
      Override this method to load additional metadata along with the parameter values.

      If you override saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. It then restores the input shapes.

      Parameters:
      loadVersion - the version used for loading this metadata.
      is - the input stream we are loading from
      Throws:
      IOException - loading failed
      MalformedModelException - the data can be read but has the wrong format
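The version check described for loadMetadata, together with an override that keeps an older format readable, can be sketched with plain java.io streams. SketchVersionedBlock, SketchBackwardCompatibleBlock, the hiddenSize field, and SketchMalformedModelException (standing in for MalformedModelException) are all hypothetical. The layout of a version byte followed by the metadata is an assumption of the sketch.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical stand-in for MalformedModelException. */
class SketchMalformedModelException extends Exception {
    SketchMalformedModelException(String message) {
        super(message);
    }
}

/** Hypothetical block that stores a version byte ahead of its metadata. */
class SketchVersionedBlock {
    final byte version;
    int hiddenSize; // example of additional metadata a subclass might persist

    SketchVersionedBlock(byte version) {
        this.version = version;
    }

    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeInt(hiddenSize);
    }

    /** Default behavior: reject any version other than this block's own. */
    protected void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, SketchMalformedModelException {
        if (loadVersion != version) {
            throw new SketchMalformedModelException("unsupported version " + loadVersion);
        }
        hiddenSize = is.readInt();
    }

    /** Writes the version byte, then the metadata. */
    byte[] saveToBytes() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeByte(version);
            saveMetadata(os);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads the version byte, then the metadata; returns false on a version mismatch. */
    boolean loadFromBytes(byte[] data) {
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            loadMetadata(is.readByte(), is);
            return true;
        } catch (SketchMalformedModelException e) {
            return false;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}

/** Hypothetical version-2 block that still reads version-1 data, which had no hiddenSize. */
class SketchBackwardCompatibleBlock extends SketchVersionedBlock {
    SketchBackwardCompatibleBlock() {
        super((byte) 2);
    }

    @Override
    protected void loadMetadata(byte loadVersion, DataInputStream is)
            throws IOException, SketchMalformedModelException {
        if (loadVersion == 1) {
            hiddenSize = 16; // hypothetical default for the old format
            return;
        }
        super.loadMetadata(loadVersion, is);
    }
}
```

The override handles only the old version it knows about and defers everything else to the default check, so unknown versions are still rejected.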
    • saveInputShapes

      protected void saveInputShapes(DataOutputStream os) throws IOException
      Throws:
      IOException
    • readInputShapes

      protected void readInputShapes(DataInputStream is) throws IOException
      Throws:
      IOException
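One plausible way to write and read back an array of shapes, as saveInputShapes and readInputShapes must do, is a shape count followed by each shape's rank and dimensions. SketchShapeCodec and that layout are hypothetical, long[] stands in for Shape, and this is not necessarily how DJL encodes shapes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical shape codec: shape count, then each shape's rank followed by its dimensions. */
class SketchShapeCodec {
    static void saveInputShapes(DataOutputStream os, long[][] inputShapes) throws IOException {
        os.writeInt(inputShapes.length);
        for (long[] shape : inputShapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    static long[][] readInputShapes(DataInputStream is) throws IOException {
        long[][] shapes = new long[is.readInt()][];
        for (int i = 0; i < shapes.length; i++) {
            shapes[i] = new long[is.readInt()];
            for (int j = 0; j < shapes[i].length; j++) {
                shapes[i][j] = is.readLong();
            }
        }
        return shapes;
    }

    /** Writes the shapes to memory and reads them back. */
    static long[][] roundTrip(long[][] shapes) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(bytes)) {
                saveInputShapes(os, shapes);
            }
            try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return readInputShapes(is);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```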
    • toString

      public String toString()
      Overrides:
      toString in class Object
    • getInputShapes

      public Shape[] getInputShapes()
      Returns the input shapes of the block. The input shapes are only available after the block is initialized; otherwise, an IllegalStateException is thrown.
      Specified by:
      getInputShapes in interface Block
      Returns:
      the input shapes of the block
    • getOutputDataTypes

      public DataType[] getOutputDataTypes()
      Returns the output data types of the block.
      Specified by:
      getOutputDataTypes in interface Block
      Returns:
      the output data types of the block