Package ai.djl.nn

Class Parameter

java.lang.Object
ai.djl.nn.Parameter
All Implemented Interfaces:
AutoCloseable

public class Parameter extends Object implements AutoCloseable
Parameter is a container class that holds a learnable parameter of a model.

Every Parameter is associated with a Block. The output of the block's forward function depends on the values in the Parameter. During training, the values in the Parameter are updated to reflect the training data. This process forms the crux of learning.

  • Method Details

    • getId

      public String getId()
      Gets the ID of this Parameter.
      Returns:
      the ID of this Parameter
    • getName

      public String getName()
      Gets the name of this Parameter.
      Returns:
      the name of this Parameter
    • getType

      public Parameter.Type getType()
      Gets the type of this Parameter.
      Returns:
      the type of this Parameter
    • setArray

      public void setArray(NDArray array)
      Sets the values of this Parameter.
      Parameters:
      array - the NDArray that contains values of this Parameter
    • setShape

      public void setShape(Shape shape)
      Sets the shape of this Parameter.
      Parameters:
      shape - the shape of this Parameter
    • getShape

      public Shape getShape()
      Gets the shape of this Parameter.
      Returns:
      the shape of this Parameter
    • getArray

      public NDArray getArray()
      Gets the values of this Parameter as an NDArray.
      Returns:
      an NDArray that contains values of this Parameter
    • requiresGradient

      public boolean requiresGradient()
      Returns whether this parameter needs gradients to be computed.
      Returns:
      whether this parameter needs gradients to be computed
    • freeze

      public void freeze(boolean freeze)
      Freezes or unfreezes the parameter for training.

      During training, and especially during transfer learning, it is common to train only part of the model. Freezing can be used to prevent certain parts from being trained.

      This modifies the requiresGradient() of the parameter.

      Parameters:
      freeze - true if the parameter should be frozen (freeze == !requiresGradient())
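
The relationship between freeze(boolean) and requiresGradient() can be illustrated with a small, dependency-free model. This is only a sketch: SketchParam and step are hypothetical stand-ins written for this example, not DJL classes, and the update rule is a plain gradient-descent step chosen for illustration. With a real Parameter, the documented contract is the same one the sketch mirrors: after freeze(true), requiresGradient() is false.

```java
import java.util.List;

public class FreezeSketch {

    /** Hypothetical stand-in for a parameter; this is NOT the DJL Parameter class. */
    static final class SketchParam {
        final String name;
        double value;
        private boolean requiresGrad = true;

        SketchParam(String name, double value) {
            this.name = name;
            this.value = value;
        }

        // Mirrors the documented contract: freeze == !requiresGradient()
        void freeze(boolean freeze) {
            requiresGrad = !freeze;
        }

        boolean requiresGradient() {
            return requiresGrad;
        }
    }

    /** One gradient-descent step (an assumption of this sketch) that skips frozen parameters. */
    static void step(List<SketchParam> params, double[] grads, double learningRate) {
        for (int i = 0; i < params.size(); i++) {
            SketchParam p = params.get(i);
            if (p.requiresGradient()) {
                p.value -= learningRate * grads[i];
            }
        }
    }

    public static void main(String[] args) {
        SketchParam backbone = new SketchParam("backbone.weight", 1.0);
        SketchParam head = new SketchParam("head.weight", 1.0);

        // Transfer-learning style setup: keep the pretrained part fixed.
        backbone.freeze(true);

        step(List.of(backbone, head), new double[] {0.5, 0.5}, 0.1);

        System.out.println(backbone.name + " = " + backbone.value);
        System.out.println(head.name + " = " + head.value);
    }
}
```

Only the unfrozen parameter changes after the step; calling freeze(false) makes a parameter trainable again.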
    • isInitialized

      public boolean isInitialized()
      Checks if this Parameter is initialized.
      Returns:
      true if this Parameter is initialized
    • setInitializer

      public void setInitializer(Initializer initializer)
      Sets the Initializer for this Parameter.
      Parameters:
      initializer - the initializer to be set
    • getInitializer

      public Initializer getInitializer()
      Gets the Initializer for this Parameter.
      Returns:
      the initializer of this Parameter
    • initialize

      public void initialize(NDManager manager, DataType dataType)
      Initializes the parameter with the given NDManager and the given DataType.
      Parameters:
      manager - an NDManager to create the arrays
      dataType - the datatype of the Parameter
    • save

      public void save(DataOutputStream dos) throws IOException
      Writes the parameter NDArrays to the given output stream.
      Parameters:
      dos - the output stream to write to
      Throws:
      IOException - if the write operation fails
    • load

      public void load(NDManager manager, DataInputStream dis) throws IOException, MalformedModelException
      Loads parameter NDArrays from InputStream.

      Currently, parameters cannot be deserialized into the exact subclass of NDArray. A SparseNDArray will be loaded as a plain NDArray.

      Parameters:
      manager - the NDManager
      dis - the InputStream
      Throws:
      IOException - if the read operation fails
      MalformedModelException - if the model is not in the expected format (parameters)
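
The save/load pair follows the usual DataOutputStream/DataInputStream round-trip pattern. The sketch below shows that pattern with the standard library only. Everything specific in it is an assumption made for illustration: the MAGIC marker, the field order (name, shape, values) and the helper names are hypothetical and do not describe DJL's actual serialization format. A plain IOException stands in for the format check that the real method reports as MalformedModelException.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

public class ParamStreamSketch {

    // Hypothetical marker value; NOT DJL's real on-disk layout.
    static final int MAGIC = 0x50415241;

    /** Writes name, shape and values in a layout invented for this sketch. */
    static void save(DataOutputStream dos, String name, long[] shape, float[] values)
            throws IOException {
        dos.writeInt(MAGIC);
        dos.writeUTF(name);
        dos.writeInt(shape.length);
        for (long dim : shape) {
            dos.writeLong(dim);
        }
        dos.writeInt(values.length);
        for (float v : values) {
            dos.writeFloat(v);
        }
    }

    /** Reads the values back, rejecting input that does not match the sketch layout. */
    static float[] load(DataInputStream dis, String expectedName) throws IOException {
        if (dis.readInt() != MAGIC) {
            throw new IOException("Unexpected header");
        }
        String name = dis.readUTF();
        if (!name.equals(expectedName)) {
            throw new IOException("Expected parameter " + expectedName + " but found " + name);
        }
        int rank = dis.readInt();
        long size = 1;
        for (int i = 0; i < rank; i++) {
            size *= dis.readLong();
        }
        int count = dis.readInt();
        if (count != size) {
            throw new IOException("Value count does not match shape");
        }
        float[] values = new float[count];
        for (int i = 0; i < count; i++) {
            values[i] = dis.readFloat();
        }
        return values;
    }

    /** Convenience helper: saves to memory and loads the values back. */
    static float[] roundTrip(String name, long[] shape, float[] values) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream dos = new DataOutputStream(bytes)) {
                save(dos, name, shape, values);
            }
            try (DataInputStream dis =
                    new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return load(dis, name);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        float[] restored = roundTrip("weight", new long[] {2, 2}, new float[] {1f, 2f, 3f, 4f});
        System.out.println(Arrays.toString(restored));
    }
}
```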
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
    • builder

      public static Parameter.Builder builder()
      Creates a builder to build a Parameter.

      Methods that start with set are for required fields, and methods that start with opt are for optional fields.

      Returns:
      a new builder
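
The set/opt naming convention can be sketched with a small stand-alone builder. The classes below are hypothetical and written only for this example; they are not Parameter.Builder. In particular, the field choices and the check in build() that rejects a missing required field are assumptions of the sketch, since this page does not state how the real builder handles unset fields.

```java
import java.util.Arrays;

public class BuilderConventionSketch {

    /** Hypothetical immutable value; this is NOT the DJL Parameter class. */
    static final class SketchParam {
        final String name;
        final String type;
        final long[] shape; // null when the optional shape was not provided
        final boolean requiresGrad;

        private SketchParam(Builder b) {
            this.name = b.name;
            this.type = b.type;
            this.shape = b.shape;
            this.requiresGrad = b.requiresGrad;
        }
    }

    /** Hypothetical builder: set* methods are required, opt* methods are optional. */
    static final class Builder {
        private String name;
        private String type;
        private long[] shape;
        private boolean requiresGrad = true; // optional fields carry defaults

        Builder setName(String name) {
            this.name = name;
            return this;
        }

        Builder setType(String type) {
            this.type = type;
            return this;
        }

        Builder optShape(long... shape) {
            this.shape = shape.clone();
            return this;
        }

        Builder optRequiresGrad(boolean requiresGrad) {
            this.requiresGrad = requiresGrad;
            return this;
        }

        SketchParam build() {
            // Sketch-only validation: required fields must have been set.
            if (name == null || type == null) {
                throw new IllegalStateException("setName and setType are required");
            }
            return new SketchParam(this);
        }
    }

    static Builder builder() {
        return new Builder();
    }

    public static void main(String[] args) {
        SketchParam weight = builder().setName("weight").setType("WEIGHT").optShape(4, 3).build();
        System.out.println(weight.name + " " + Arrays.toString(weight.shape));
    }
}
```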