Package ai.djl.nn
Class Parameter
- java.lang.Object
  - ai.djl.nn.Parameter
- All Implemented Interfaces:
java.lang.AutoCloseable
public class Parameter extends java.lang.Object implements java.lang.AutoCloseable
Parameter is a container class that holds a learnable parameter of a model. Every Parameter is associated with a Block. The output of the block's forward function depends on the values in the Parameter. During training, the values in the Parameter are updated to reflect the training data. This process forms the crux of learning.

See Also:
- The D2L chapter on parameter management
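The description above says that training updates the values in a Parameter. A minimal, framework-free sketch of that idea follows. It is plain Java and does not use DJL's API or optimizers. The hypothetical LearningSketch.fit treats a single float w as the learnable parameter of the model y = w * x. It then moves w toward the training data with plain gradient descent on mean squared error.

```java
/**
 * Hypothetical illustration only, not DJL code: a single learnable weight w for
 * the model y = w * x, trained by plain gradient descent on mean squared error.
 */
final class LearningSketch {
    private LearningSketch() {}

    /** Runs `steps` gradient-descent updates starting from w = 0 and returns w. */
    static float fit(float[] xs, float[] ys, float learningRate, int steps) {
        float w = 0f; // the "parameter": starts at an arbitrary initial value
        for (int s = 0; s < steps; s++) {
            float grad = 0f;
            for (int i = 0; i < xs.length; i++) {
                float err = w * xs[i] - ys[i]; // forward output minus target
                grad += 2f * err * xs[i];      // derivative of err^2 with respect to w
            }
            grad /= xs.length;                 // average over the training data
            w -= learningRate * grad;          // update step: the value changes to fit the data
        }
        return w;
    }
}
```

With xs = {1, 2, 3} and ys = {2, 4, 6}, the weight converges toward 2. In DJL the corresponding values live in an NDArray held by the Parameter. This sketch does not reproduce how DJL computes or applies gradients.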
Nested Class Summary
- static class Parameter.Builder: A builder to construct a Parameter.
- static class Parameter.Type: Enumerates the types of Parameter.
Method Summary
- static Parameter.Builder builder(): Creates a builder to build a Parameter.
- void close()
- NDArray getArray(): Gets the values of this Parameter as an NDArray.
- java.lang.String getId(): Gets the ID of this Parameter.
- java.lang.String getName(): Gets the name of this Parameter.
- Parameter.Type getType(): Gets the type of this Parameter.
- void initialize(NDManager manager, DataType dataType)
- boolean isInitialized(): Checks if this Parameter is initialized.
- void load(NDManager manager, java.io.DataInputStream dis): Loads parameter NDArrays from the given input stream.
- boolean requiresGradient(): Returns whether this parameter needs gradients to be computed.
- void save(java.io.DataOutputStream dos): Writes the parameter NDArrays to the given output stream.
- void setArray(NDArray array): Sets the values of this Parameter.
- void setInitializer(Initializer initializer): Sets the Initializer for this Parameter, if not already set.
- void setShape(Shape shape): Sets the shape of this Parameter.
Method Detail

getId
public java.lang.String getId()
Gets the ID of this Parameter.
Returns: the ID of this Parameter

getName
public java.lang.String getName()
Gets the name of this Parameter.
Returns: the name of this Parameter

getType
public Parameter.Type getType()
Gets the type of this Parameter.
Returns: the type of this Parameter

setArray
public void setArray(NDArray array)
Sets the values of this Parameter.
Parameters:
- array: the NDArray that contains the values of this Parameter

setShape
public void setShape(Shape shape)
Sets the shape of this Parameter.
Parameters:
- shape: the shape of this Parameter

getArray
public NDArray getArray()
Gets the values of this Parameter as an NDArray.
Returns: an NDArray that contains the values of this Parameter

requiresGradient
public boolean requiresGradient()
Returns whether this parameter needs gradients to be computed.
Returns: whether this parameter needs gradients to be computed

isInitialized
public boolean isInitialized()
Checks if this Parameter is initialized.
Returns: true if this Parameter is initialized

setInitializer
public void setInitializer(Initializer initializer)
Sets the Initializer for this Parameter, if not already set.
Parameters:
- initializer: the initializer to be set
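The description says setInitializer takes effect only if no initializer is set yet. The sketch below models that rule as documented; it is not taken from DJL's source. The hypothetical ToyInitializer interface and InitializerSlot class stand in for DJL's Initializer and Parameter:

```java
/** Hypothetical stand-in for an initializer: produces `size` starting values. */
interface ToyInitializer {
    float[] init(int size);
}

/** Hypothetical holder that applies the "set only if not already set" rule. */
final class InitializerSlot {
    private ToyInitializer initializer; // null until the first successful call

    /** Stores the candidate only if no initializer has been stored yet. */
    void setInitializer(ToyInitializer candidate) {
        if (initializer == null) {
            initializer = candidate;
        }
    }

    ToyInitializer getInitializer() {
        return initializer;
    }
}
```

Under this rule, the caller that sets an initializer first wins, and the slot silently ignores later calls.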

initialize
public void initialize(NDManager manager, DataType dataType)
Initializes this Parameter using the given NDManager and DataType.
Parameters:
- manager: an NDManager to create the arrays
- dataType: the data type of the Parameter
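Read together, setShape, setInitializer, initialize, isInitialized, and getArray suggest a lazy lifecycle. The shape and initializer are recorded first, and the values exist only after initialize runs. The hypothetical LazyParam below models that order in plain Java, with these stand-ins:
- float[] stands in for NDArray.
- int[] stands in for Shape.
- IntFunction<float[]> stands in for Initializer.

It ignores NDManager and DataType. Leaving existing values untouched on a second initialize call is an assumption of this sketch:

```java
import java.util.Arrays;
import java.util.function.IntFunction;

/** Hypothetical sketch, not DJL code: values are allocated only when initialize() runs. */
final class LazyParam {
    private int[] shape;
    private IntFunction<float[]> initializer;
    private float[] values; // null until initialized

    void setShape(int... shape) {
        this.shape = shape.clone();
    }

    void setInitializer(IntFunction<float[]> initializer) {
        this.initializer = initializer;
    }

    boolean isInitialized() {
        return values != null;
    }

    /** Allocates values for the recorded shape; assumed to be a no-op once initialized. */
    void initialize() {
        if (isInitialized()) {
            return;
        }
        if (shape == null || initializer == null) {
            throw new IllegalStateException("shape and initializer must be set first");
        }
        int size = Arrays.stream(shape).reduce(1, (a, b) -> a * b); // element count
        values = initializer.apply(size);
    }

    float[] getArray() {
        if (!isInitialized()) {
            throw new IllegalStateException("parameter is not initialized");
        }
        return values;
    }
}
```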

save
public void save(java.io.DataOutputStream dos) throws java.io.IOException
Writes the parameter NDArrays to the given output stream.
Parameters:
- dos: the output stream to write to
Throws:
- java.io.IOException: if the write operation fails

load
public void load(NDManager manager, java.io.DataInputStream dis) throws java.io.IOException, MalformedModelException
Loads parameter NDArrays from the given input stream.
Deserialization does not currently restore the exact NDArray subclass: a SparseNDArray is loaded as a plain NDArray.
Parameters:
- manager: the NDManager
- dis: the input stream
Throws:
- java.io.IOException: if reading fails
- MalformedModelException: if the model (its parameters) is not in the expected format
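save and load are a matched pair. Whatever save writes to a DataOutputStream, load must read back in the same order from a DataInputStream, and load reports input that does not match the expected layout as malformed. The round-trip sketch below uses an invented layout: a marker int, a length, then the float values. ParamCodec and its format are hypothetical and are not what DJL writes. IOException stands in for MalformedModelException:

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical codec with an invented layout, for illustration only. */
final class ParamCodec {
    private static final int MAGIC = 0x50524D31; // arbitrary marker, ASCII "PRM1"

    private ParamCodec() {}

    /** Writes marker, element count, then each value, in that order. */
    static void save(float[] values, DataOutputStream dos) throws IOException {
        dos.writeInt(MAGIC);
        dos.writeInt(values.length);
        for (float v : values) {
            dos.writeFloat(v);
        }
        dos.flush();
    }

    /** Reads the fields back in the same order save() wrote them. */
    static float[] load(DataInputStream dis) throws IOException {
        if (dis.readInt() != MAGIC) {
            // plays the role of MalformedModelException in this sketch
            throw new IOException("unexpected parameter format");
        }
        int n = dis.readInt();
        if (n < 0) {
            throw new IOException("negative element count");
        }
        float[] values = new float[n];
        for (int i = 0; i < n; i++) {
            values[i] = dis.readFloat();
        }
        return values;
    }
}
```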

close
public void close()
Specified by: close in interface java.lang.AutoCloseable
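Because Parameter implements java.lang.AutoCloseable, it can be managed with try-with-resources, which calls close() even if the body throws. This page does not say what close() releases. The hypothetical ClosableParam below assumes it drops the held values and makes close() idempotent, purely to illustrate the pattern:

```java
/** Hypothetical sketch: a parameter-like holder that releases its values on close(). */
final class ClosableParam implements AutoCloseable {
    private float[] values;

    ClosableParam(int size) {
        values = new float[size];
    }

    boolean isClosed() {
        return values == null;
    }

    /** Drops the reference to the values; calling it again has no further effect. */
    @Override
    public void close() {
        values = null;
    }
}
```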

builder
public static Parameter.Builder builder()
Creates a builder to build a Parameter.
Builder methods whose names start with set configure required fields; those starting with opt configure optional fields.
Returns: a new builder
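The builder convention described above can be sketched with a hypothetical ToyParameter class: set-prefixed methods configure required fields and opt-prefixed methods configure optional ones. In this sketch, build() rejects a builder whose required fields were never set. The following are assumptions for illustration and are not DJL's Parameter.Builder API:
- the method names setName, setType, and optRequiresGrad
- the String type field
- a default of true for the optional field

```java
/** Hypothetical value class built by a set-/opt-style builder. */
final class ToyParameter {
    private final String name;          // required
    private final String type;          // required
    private final boolean requiresGrad; // optional

    private ToyParameter(Builder builder) {
        this.name = builder.name;
        this.type = builder.type;
        this.requiresGrad = builder.requiresGrad;
    }

    static Builder builder() {
        return new Builder();
    }

    String getName() { return name; }

    String getType() { return type; }

    boolean requiresGradient() { return requiresGrad; }

    static final class Builder {
        private String name;
        private String type;
        private boolean requiresGrad = true; // assumed default for the optional field

        private Builder() {}

        /** Required field. */
        Builder setName(String name) {
            this.name = name;
            return this;
        }

        /** Required field. */
        Builder setType(String type) {
            this.type = type;
            return this;
        }

        /** Optional field; the default is kept when this is not called. */
        Builder optRequiresGrad(boolean requiresGrad) {
            this.requiresGrad = requiresGrad;
            return this;
        }

        /** Fails fast if a required (set-prefixed) field was never provided. */
        ToyParameter build() {
            if (name == null || type == null) {
                throw new IllegalStateException("name and type are required");
            }
            return new ToyParameter(this);
        }
    }
}
```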