Class Linear

All Implemented Interfaces:
Block

public class Linear extends AbstractBlock
A Linear block applies a linear transformation \(Y = XW^T + b\).

It has the following shapes:

  • input X: [x1, x2, ..., xn, input_dim]
  • weight W: [units, input_dim]
  • bias b: [units]
  • output Y: [x1, x2, ..., xn, units]
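Written element by element, and consistent with \(Y = XW^T + b\) and the shapes above, each output entry depends only on the matching length-input_dim vector of X. The leading indices \(i_1, \ldots, i_n\) are carried through unchanged:

```latex
Y_{i_1 \ldots i_n,\,u} = \sum_{k=1}^{\text{input\_dim}} X_{i_1 \ldots i_n,\,k}\, W_{u,\,k} + b_u,
\qquad u = 1, \ldots, \text{units}
```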

It is most commonly used with a simple batched 1D input, in which case the shapes are:

  • input X: [batch_num, input_dim]
  • weight W: [units, input_dim]
  • bias b: [units]
  • output Y: [batch_num, units]
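For the batched case, the arithmetic behind \(Y = XW^T + b\) can be sketched in plain Java. This is an illustration only, not DJL's implementation. It assumes nested double arrays in place of NDArray, and the class name BatchedLinearSketch is hypothetical. Passing null for the bias mirrors the two-argument linear overload.

```java
import java.util.Arrays;

/** Plain-Java sketch of Y = X W^T + b for a batched 2D input (illustration only, not DJL code). */
final class BatchedLinearSketch {

    /**
     * x: [batch_num][input_dim], w: [units][input_dim], b: [units] or null for no bias.
     * Returns y: [batch_num][units].
     */
    static double[][] linear(double[][] x, double[][] w, double[] b) {
        int batch = x.length;
        int units = w.length;
        if (b != null && b.length != units) {
            throw new IllegalArgumentException("bias must have shape [units]");
        }
        double[][] y = new double[batch][units];
        for (int i = 0; i < batch; i++) {
            for (int u = 0; u < units; u++) {
                if (w[u].length != x[i].length) {
                    throw new IllegalArgumentException("input_dim of X and W must match");
                }
                // Dot product of row i of X with row u of W (that is, column u of W^T).
                double sum = (b == null) ? 0.0 : b[u];
                for (int k = 0; k < x[i].length; k++) {
                    sum += x[i][k] * w[u][k];
                }
                y[i][u] = sum;
            }
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 3}, {4, 5, 6}}; // [batch_num = 2, input_dim = 3]
        double[][] w = {{1, 0, 0}, {0, 1, 1}}; // [units = 2, input_dim = 3]
        double[] b = {0.5, -1};                // [units = 2]
        System.out.println(Arrays.deepToString(linear(x, w, b))); // [[1.5, 4.0], [4.5, 10.0]]
    }
}
```

Each row of W acts as one output unit, which is why W is stored as [units, input_dim] and applied transposed.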

The Linear block should be constructed using Linear.Builder.

  • Constructor Details

  • Method Details

    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputs)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputs - the shapes of the inputs
      Returns:
      the expected output shapes of the block
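Based on the shapes listed for this block, the output shape keeps every leading dimension and replaces only the last one (input_dim) with units. The following plain-Java sketch illustrates that rule. It uses long[] arrays in place of ai.djl.ndarray.types.Shape, and the class and method names are hypothetical, not part of DJL.

```java
import java.util.Arrays;

/** Illustrative sketch (not DJL code): the output shape implied by [x1, ..., xn, input_dim]. */
final class LinearShapeRule {

    /** Returns [x1, ..., xn, units] for an input shape [x1, ..., xn, input_dim]. */
    static long[] outputShape(long[] inputShape, long units) {
        if (inputShape.length == 0) {
            throw new IllegalArgumentException("input must have at least one dimension");
        }
        long[] out = Arrays.copyOf(inputShape, inputShape.length);
        out[out.length - 1] = units; // input_dim is replaced by units; leading dims are kept
        return out;
    }

    public static void main(String[] args) {
        // [x1, x2, input_dim] = [4, 7, 16] with units = 32
        System.out.println(Arrays.toString(outputShape(new long[] {4, 7, 16}, 32))); // [4, 7, 32]
    }
}
```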
    • describeInput

      public ai.djl.util.PairList<String,Shape> describeInput()
      Returns a PairList of input names and shapes.
      Specified by:
      describeInput in interface Block
      Overrides:
      describeInput in class AbstractBaseBlock
      Returns:
      the PairList of input names and shapes
    • beforeInitialize

      protected void beforeInitialize(Shape... inputShapes)
      Performs any action necessary before initialization, for example keeping the input information or verifying the layout.
      Overrides:
      beforeInitialize in class AbstractBaseBlock
      Parameters:
      inputShapes - the expected shapes of the input
    • prepare

      public void prepare(Shape[] inputShapes)
      Sets the shapes of the Parameters.
      Overrides:
      prepare in class AbstractBaseBlock
      Parameters:
      inputShapes - the shapes of inputs
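Going by the shapes documented for Linear, the parameter shapes depend only on the last input dimension and on units. The leading dimensions x1..xn never affect W or b. The following plain-Java sketch shows that derivation. It is hypothetical and does not reproduce DJL's prepare implementation, and it uses long[] in place of Shape.

```java
import java.util.Arrays;

/** Hypothetical sketch (not DJL's prepare): derive Linear parameter shapes from an input shape. */
final class LinearParameterShapes {

    /** weight W: [units, input_dim], where input_dim is the last input dimension. */
    static long[] weightShape(long[] inputShape, long units) {
        if (inputShape.length == 0) {
            throw new IllegalArgumentException("input must have at least one dimension");
        }
        long inputDim = inputShape[inputShape.length - 1];
        return new long[] {units, inputDim};
    }

    /** bias b: [units], independent of the input shape. */
    static long[] biasShape(long units) {
        return new long[] {units};
    }

    public static void main(String[] args) {
        long[] input = {32, 10, 64}; // [x1, x2, input_dim]
        System.out.println(Arrays.toString(weightShape(input, 128))); // [128, 64]
        System.out.println(Arrays.toString(biasShape(128)));          // [128]
    }
}
```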
    • saveMetadata

      protected void saveMetadata(DataOutputStream os) throws IOException
      Override this method to save additional data apart from parameter values.

      This default implementation saves the currently set input shapes.

      Overrides:
      saveMetadata in class AbstractBaseBlock
      Parameters:
      os - the non-null output stream the parameter values and metadata are written to
      Throws:
      IOException - saving failed
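To make "saves the currently set input shapes" concrete, the following sketch writes a list of shapes to a DataOutputStream using only the JDK. The byte layout (shape count, then each shape's rank followed by its dimensions) and the class name are hypothetical choices for illustration. This is not DJL's actual on-disk format.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Sketch of saving input shapes as extra metadata; the layout is hypothetical, not DJL's format. */
final class ShapeMetadataWriter {

    /** Layout: shape count (int), then for each shape its rank (int) followed by each dimension (long). */
    static void saveShapes(DataOutputStream os, long[][] shapes) throws IOException {
        os.writeInt(shapes.length);
        for (long[] shape : shapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            saveShapes(os, new long[][] {{8, 16}});
        }
        // 4 bytes (count) + 4 bytes (rank) + 2 * 8 bytes (dims) = 24 bytes
        System.out.println(bytes.size()); // 24
    }
}
```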
    • loadMetadata

      public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
      Override this method to load additional metadata along with the parameter values.

      If you override AbstractBaseBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. This default implementation checks whether the version number matches and, if it does not, throws a MalformedModelException. It then restores the input shapes.

      Overrides:
      loadMetadata in class AbstractBaseBlock
      Parameters:
      loadVersion - the version used for loading this metadata.
      is - the input stream we are loading from
      Throws:
      IOException - loading failed
      MalformedModelException - the data can be loaded but has the wrong format
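The load side described above, checking the version and then restoring the input shapes, can be sketched with the JDK alone. Everything specific here is hypothetical: the supported version number, the byte layout (shape count, then each shape's rank and dimensions), and MalformedMetadataException, which stands in for DJL's MalformedModelException so that the sketch needs no DJL dependency.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/** Sketch: verify the metadata version, then read back shapes in a hypothetical layout. */
final class ShapeMetadataReader {

    static final byte SUPPORTED_VERSION = 1; // hypothetical version number

    /** Hypothetical stand-in for DJL's MalformedModelException. */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    static long[][] loadShapes(byte loadVersion, DataInputStream is)
            throws IOException, MalformedMetadataException {
        if (loadVersion != SUPPORTED_VERSION) {
            // The stream is readable, but its format is not one this reader understands.
            throw new MalformedMetadataException("Unsupported metadata version: " + loadVersion);
        }
        int count = is.readInt();
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            int rank = is.readInt();
            shapes[i] = new long[rank];
            for (int d = 0; d < rank; d++) {
                shapes[i][d] = is.readLong();
            }
        }
        return shapes;
    }

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream os = new DataOutputStream(bytes);
        os.writeInt(1);   // one shape
        os.writeInt(2);   // rank 2
        os.writeLong(8);
        os.writeLong(16);
        DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        System.out.println(Arrays.toString(loadShapes((byte) 1, is)[0])); // [8, 16]
    }
}
```

Keeping the version check ahead of any reads means an unknown format is rejected before partially parsed data can be used.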
    • linear

      public static NDList linear(NDArray input, NDArray weight)
      Applies a linear transformation to the incoming data.
      Parameters:
      input - input X: [x1, x2, …, xn, input_dim]
      weight - weight W: [units, input_dim]
      Returns:
      output Y: [x1, x2, …, xn, units]
    • linear

      public static NDList linear(NDArray input, NDArray weight, NDArray bias)
      Applies a linear transformation to the incoming data.
      Parameters:
      input - input X: [x1, x2, …, xn, input_dim]
      weight - weight W: [units, input_dim]
      bias - bias b: [units]
      Returns:
      output Y: [x1, x2, …, xn, units]
    • builder

      public static Linear.Builder builder()
      Creates a builder to build a Linear.
      Returns:
      a new builder