Package ai.djl.nn
Class AbstractBaseBlock
java.lang.Object
ai.djl.nn.AbstractBaseBlock
- All Implemented Interfaces:
Block
- Direct Known Subclasses:
AbstractBlock, AbstractSymbolBlock
This provides shared functionality for both the DJL-based AbstractBlocks and the imported AbstractSymbolBlocks.
Field Summary
- protected List<String> inputNames: List of names for the input; named inputs should be set manually in the subclass.
- protected Shape[] inputShapes: The shape of the input for this block, set by the initialization process.
- protected DataType[] outputDataTypes
- protected byte version: The model version of this block, used for checking whether parameters are still valid during parameter loading.
Constructor Summary
- AbstractBaseBlock(): Constructs a new AbstractBaseBlock instance.
- AbstractBaseBlock(byte version): Builds an empty block with the given version for parameter serialization.
Method Summary
- protected void beforeInitialize(Shape... inputShapes): Performs any action necessary before initialization.
- void cast(DataType dataType): Guaranteed to throw an exception.
- void clear(): Closes all the parameters of the block.
- PairList<String, Shape> describeInput(): Returns a PairList of input names and shapes.
- final NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params): Applies the operating function of the block once.
- NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String, Object> params): A forward call using both training data and labels.
- protected abstract NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params): A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String, Object> params): A helper for Block.forward(ParameterStore, NDList, NDList, PairList) after initialization.
- Shape[] getInputShapes(): Returns the input shapes of the block.
- DataType[] getOutputDataTypes(): Returns the output data types of the block.
- ParameterList getParameters(): Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
- void initialize(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the parameters of the block, sets them to require gradients if required, and infers the block's input shapes.
- protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the child blocks of this block.
- boolean isInitialized(): Returns whether the block is initialized (the block has input shapes and all parameters have non-null arrays).
- protected void loadMetadata(byte loadVersion, DataInputStream is): Override this to load additional metadata with the parameter values.
- void loadParameters(NDManager manager, DataInputStream is): Loads the parameters from the given input stream.
- protected void prepare(Shape[] inputShapes): Sets the shapes of the Parameters.
- protected void readInputShapes(DataInputStream is)
- protected void saveInputShapes(DataOutputStream os)
- protected void saveMetadata(DataOutputStream os): Override this method to save additional data apart from parameter values.
- void saveParameters(DataOutputStream os): Writes the parameters of the block to the given output stream.
- void setInitializer(Initializer initializer, Parameter.Type params): Sets an Initializer on all the parameters in the block that match the given parameter type.
- void setInitializer(Initializer initializer, String paramName): Sets an Initializer on the specified direct parameter of the block, overriding the parameter's initializer if one is already set.
- void setInitializer(Initializer initializer, Predicate<Parameter> predicate): Sets an Initializer on all the parameters in the block that match the predicate.
- String toString()
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getChildren, getDirectParameters, getOutputShapes, getOutputShapes
-
Field Details
-
version
protected byte version
The model version of this block, used for checking whether parameters are still valid during parameter loading.
-
inputShapes
protected Shape[] inputShapes
The shape of the input for this block, set by the initialization process.
-
outputDataTypes
protected DataType[] outputDataTypes
-
inputNames
protected List<String> inputNames
List of names for the input; named inputs should be set manually in the subclass.
-
Constructor Details
-
AbstractBaseBlock
public AbstractBaseBlock()
Constructs a new AbstractBaseBlock instance.
-
AbstractBaseBlock
public AbstractBaseBlock(byte version)
Builds an empty block with the given version for parameter serialization.
- Parameters:
version - the version to use for parameter serialization
-
Method Details
-
forward
public final NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
Applies the operating function of the block once. This method should be called only on blocks that are initialized.
-
forward
public NDList forward(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String, Object> params)
A forward call using both training data and labels. Within this forward call, it can be assumed that training is true.
-
forwardInternal
protected abstract NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
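The relationship between the final forward methods and the forwardInternal hooks follows the template-method pattern: the base class owns the public entry point and subclasses supply only the computation. The sketch below is a minimal, stdlib-only model of that split, not DJL code. MiniBlock and Scale are hypothetical names, double[] stands in for NDList, and rejecting uninitialized blocks with an IllegalStateException is this sketch's own choice. The labels hook delegating with training set to true is an assumption based on the note that training can be assumed true in that call.

```java
import java.util.Arrays;

/** Stdlib-only sketch of the forward/forwardInternal split (hypothetical names, not DJL). */
class ForwardSketch {

    /** Hypothetical block base class; double[] stands in for NDList. */
    abstract static class MiniBlock {
        private boolean initialized;

        void initialize() {
            initialized = true;
        }

        boolean isInitialized() {
            return initialized;
        }

        /** Final entry point: shared checks live here, so subclasses cannot bypass them. */
        final double[] forward(double[] inputs, boolean training) {
            checkInitialized();
            return forwardInternal(inputs, training);
        }

        /** Labels overload: also routed through the shared check before the hook. */
        double[] forward(double[] data, double[] labels) {
            checkInitialized();
            return forwardInternal(data, labels);
        }

        private void checkInitialized() {
            if (!isInitialized()) {
                throw new IllegalStateException("Block must be initialized before forward");
            }
        }

        /** Subclass hook that performs the actual computation. */
        protected abstract double[] forwardInternal(double[] inputs, boolean training);

        /** Default labels hook (assumption): ignore labels and run the training path. */
        protected double[] forwardInternal(double[] data, double[] labels) {
            return forwardInternal(data, true);
        }
    }

    /** Hypothetical block that multiplies every input by a constant factor. */
    static class Scale extends MiniBlock {
        final double factor;
        boolean lastTraining; // records the training flag seen by the most recent call

        Scale(double factor) {
            this.factor = factor;
        }

        @Override
        protected double[] forwardInternal(double[] inputs, boolean training) {
            lastTraining = training;
            double[] out = new double[inputs.length];
            for (int i = 0; i < inputs.length; i++) {
                out[i] = inputs[i] * factor;
            }
            return out;
        }
    }

    public static void main(String[] args) {
        Scale block = new Scale(2.0);
        block.initialize();
        System.out.println(Arrays.toString(block.forward(new double[] {1, 2, 3}, false)));
        // prints [2.0, 4.0, 6.0]
    }
}
```

Because forward is final, a subclass cannot skip the shared checks; it can only change what forwardInternal computes.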
-
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList data, NDList labels, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, NDList, PairList) after initialization.
- Parameters:
parameterStore - the parameter store
data - the input data NDList
labels - the input labels NDList
params - optional parameters
- Returns:
the output of the forward pass
- See Also:
-
describeInput
Returns a PairList of input names and shapes.
- Specified by:
describeInput in interface Block
- Returns:
the PairList of input names and shapes
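describeInput pairs the inputNames field with the input shapes recorded at initialization. A minimal, stdlib-only sketch of that pairing follows, assuming hypothetical MiniBlock and QueryKeyBlock classes, long[] in place of Shape, and Map.Entry pairs in place of DJL's PairList. The positional fallback name used when a subclass sets no names is this sketch's own choice, not documented DJL behavior.

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/** Stdlib-only sketch of describeInput pairing input names with shapes (hypothetical, not DJL). */
class DescribeInputSketch {

    static class MiniBlock {
        protected List<String> inputNames = Collections.emptyList(); // set manually by subclasses
        protected long[][] inputShapes;                                // set at initialization

        List<Map.Entry<String, long[]>> describeInput() {
            if (inputShapes == null) {
                throw new IllegalStateException("Block is not initialized");
            }
            List<Map.Entry<String, long[]>> pairs = new ArrayList<>();
            for (int i = 0; i < inputShapes.length; i++) {
                // This sketch falls back to a positional name when no name was set.
                String name = i < inputNames.size() ? inputNames.get(i) : "input" + i;
                pairs.add(new AbstractMap.SimpleImmutableEntry<>(name, inputShapes[i]));
            }
            return pairs;
        }
    }

    /** Hypothetical subclass that names its first two inputs manually. */
    static class QueryKeyBlock extends MiniBlock {
        QueryKeyBlock() {
            inputNames = Arrays.asList("query", "key");
        }
    }
}
```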
-
setInitializer
Sets an Initializer on all the parameters in the block that match the given parameter type.
- Specified by:
setInitializer in interface Block
- Parameters:
initializer - the initializer to set
params - the parameter type whose parameters receive the initializer
-
setInitializer
Sets an Initializer on the specified direct parameter of the block, overriding the parameter's initializer if one is already set.
- Specified by:
setInitializer in interface Block
- Parameters:
initializer - the initializer to be set
paramName - the name of the parameter
-
setInitializer
Sets an Initializer on all the parameters in the block that match the predicate.
- Specified by:
setInitializer in interface Block
- Parameters:
initializer - the initializer to be set
predicate - the predicate that selects the parameters to set
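The three setInitializer overloads differ only in how they select parameters. A minimal, stdlib-only sketch of the selection logic follows, assuming hypothetical Node and Param classes and a DoubleSupplier in place of DJL's Initializer. In this sketch the type-based and predicate-based variants reach child parameters recursively (an assumption), while the name-based variant only looks at direct parameters, as documented for that overload; throwing on an unknown name is this sketch's own choice.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.DoubleSupplier;
import java.util.function.Predicate;

/** Stdlib-only sketch of the three setInitializer variants (hypothetical names, not DJL). */
class InitializerSketch {

    enum ParamType { WEIGHT, BIAS }

    /** Hypothetical parameter: a name, a type, and an optional initializer. */
    static class Param {
        final String name;
        final ParamType type;
        DoubleSupplier initializer;

        Param(String name, ParamType type) {
            this.name = name;
            this.type = type;
        }
    }

    /** Hypothetical block holding its direct parameters and its child blocks. */
    static class Node {
        final Map<String, Param> direct = new LinkedHashMap<>();
        final List<Node> children = new ArrayList<>();

        /** All parameters, including those of children, collected recursively. */
        List<Param> allParams() {
            List<Param> all = new ArrayList<>(direct.values());
            for (Node child : children) {
                all.addAll(child.allParams());
            }
            return all;
        }

        /** Variant 1: every parameter of the given type, expressed as a predicate. */
        void setInitializer(DoubleSupplier init, ParamType type) {
            setInitializer(init, p -> p.type == type);
        }

        /** Variant 2: one direct parameter by name, overriding any existing initializer. */
        void setInitializer(DoubleSupplier init, String paramName) {
            Param p = direct.get(paramName);
            if (p == null) {
                throw new IllegalArgumentException("No direct parameter named " + paramName);
            }
            p.initializer = init;
        }

        /** Variant 3: every parameter accepted by the predicate. */
        void setInitializer(DoubleSupplier init, Predicate<Param> predicate) {
            for (Param p : allParams()) {
                if (predicate.test(p)) {
                    p.initializer = init;
                }
            }
        }
    }
}
```

Expressing the type-based variant through the predicate-based one keeps a single traversal; whether DJL shares code this way is not stated here.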
-
initialize
Initializes the parameters of the block, sets them to require gradients if required, and infers the block's input shapes. This method must be called before calling `forward`.
- Specified by:
initialize in interface Block
- Parameters:
manager - the NDManager to initialize the parameters
dataType - the data type of the parameters
inputShapes - the shapes of the inputs to the block
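A minimal, stdlib-only sketch of an initialization lifecycle built from the hooks beforeInitialize, prepare and initializeChildBlocks, assuming hypothetical MiniBlock and Dense classes rather than DJL's and plain long[] shapes. Here beforeInitialize records the input shape, prepare derives parameter sizes from it, parameter arrays are then allocated (zero-filled, as a stand-in for an Initializer), and initializeChildBlocks runs last; that hook order is an assumption for illustration. isInitialized mirrors the documented definition: the input shape is known and every parameter has a non-null array.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Stdlib-only sketch of an initialize() lifecycle (hypothetical names and order, not DJL). */
class LifecycleSketch {

    abstract static class MiniBlock {
        protected long[] inputShape;        // recorded by beforeInitialize
        protected final int[] paramSizes;   // filled in by prepare
        protected final double[][] params;  // one slot per parameter, null until allocated
        final List<String> calls = new ArrayList<>();

        MiniBlock(int numParams) {
            paramSizes = new int[numParams];
            params = new double[numParams][];
        }

        /** Assumed order: beforeInitialize, prepare (if needed), allocate, children. */
        void initialize(long... shape) {
            beforeInitialize(shape);
            if (!isInitialized()) {
                prepare(shape);
            }
            for (int i = 0; i < params.length; i++) {
                if (params[i] == null) {
                    params[i] = new double[paramSizes[i]]; // zero-filled stand-in for an Initializer
                }
            }
            initializeChildBlocks(shape);
        }

        protected void beforeInitialize(long... shape) {
            calls.add("beforeInitialize");
            inputShape = shape.clone();
        }

        /** Subclasses compute parameter sizes from the input shape. */
        protected abstract void prepare(long[] shape);

        protected void initializeChildBlocks(long... shape) {
            calls.add("initializeChildBlocks");
        }

        /** Initialized = input shape known and every parameter has a non-null array. */
        boolean isInitialized() {
            if (inputShape == null) {
                return false;
            }
            for (double[] p : params) {
                if (p == null) {
                    return false;
                }
            }
            return true;
        }
    }

    /** Hypothetical dense layer: weight (units x inFeatures) and bias (units). */
    static class Dense extends MiniBlock {
        final int units;

        Dense(int units) {
            super(2);
            this.units = units;
        }

        @Override
        protected void prepare(long[] shape) {
            calls.add("prepare");
            int inFeatures = (int) shape[shape.length - 1];
            paramSizes[0] = units * inFeatures;
            paramSizes[1] = units;
        }
    }

    public static void main(String[] args) {
        Dense dense = new Dense(4);
        System.out.println(dense.isInitialized()); // false
        dense.initialize(8, 3);                    // batch of 8, 3 features
        System.out.println(dense.isInitialized()); // true
        System.out.println(dense.calls);           // [beforeInitialize, prepare, initializeChildBlocks]
        System.out.println(Arrays.toString(new int[] {dense.params[0].length, dense.params[1].length})); // [12, 4]
    }
}
```

Calling initialize a second time in this sketch leaves the existing arrays in place, because prepare is skipped once isInitialized() returns true and only null slots are allocated.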
-
beforeInitialize
Performs any action necessary before initialization, for example keeping the input information or verifying the layout.
- Parameters:
inputShapes - the expected shapes of the input
-
initializeChildBlocks
Initializes the child blocks of this block. Override this method if your subclass has child blocks; it is used to determine the correct input shapes for the child blocks based on the requested input shape for this block.
- Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block
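For a container block, initializeChildBlocks is where each child's input shape is derived from the container's input shape. A minimal, stdlib-only sketch for a sequential chain follows, assuming hypothetical Child, Dense and Sequential types (not DJL classes) and plain long[] shapes: each child is initialized with the output shape of the child before it.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Stdlib-only sketch: a sequential container deriving each child's input shape (hypothetical, not DJL). */
class ChildShapesSketch {

    /** Hypothetical child: can be initialized and maps an input shape to an output shape. */
    interface Child {
        void initialize(long[] inputShape);

        long[] outputShape(long[] inputShape);
    }

    /** Hypothetical dense child: replaces the last dimension with its unit count. */
    static class Dense implements Child {
        final long units;
        long[] seenInput; // the input shape this child was initialized with

        Dense(long units) {
            this.units = units;
        }

        @Override
        public void initialize(long[] inputShape) {
            seenInput = inputShape.clone();
        }

        @Override
        public long[] outputShape(long[] inputShape) {
            long[] out = inputShape.clone();
            out[out.length - 1] = units;
            return out;
        }
    }

    static class Sequential {
        final List<Child> children = new ArrayList<>();

        /** Each child is initialized with the output shape of the previous child. */
        void initializeChildBlocks(long[] inputShape) {
            long[] current = inputShape;
            for (Child child : children) {
                child.initialize(current);
                current = child.outputShape(current);
            }
        }
    }

    public static void main(String[] args) {
        Sequential seq = new Sequential();
        Dense first = new Dense(16);
        Dense second = new Dense(10);
        seq.children.add(first);
        seq.children.add(second);
        seq.initializeChildBlocks(new long[] {32, 784});
        System.out.println(Arrays.toString(second.seenInput)); // [32, 16]
    }
}
```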
-
prepare
Sets the shapes of the Parameters.
- Parameters:
inputShapes - the shapes of the inputs
-
getParameters
Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
- Specified by:
getParameters in interface Block
- Returns:
the list of all parameters of the block
-
isInitialized
public boolean isInitialized()
Returns whether the block is initialized (the block has input shapes and all parameters have non-null arrays).
- Specified by:
isInitialized in interface Block
- Returns:
whether the block is initialized
-
clear
public void clear()
Closes all the parameters of the block. All the updates made during training will be lost.
-
cast
Guaranteed to throw an exception. Not yet implemented.
-
saveParameters
Writes the parameters of the block to the given output stream.
- Specified by:
saveParameters in interface Block
- Parameters:
os - the output stream to save the parameters to
- Throws:
IOException - if an I/O error occurs
-
loadParameters
public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException
Loads the parameters from the given input stream.
- Specified by:
loadParameters in interface Block
- Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
- Throws:
IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
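saveParameters and loadParameters are symmetric: whatever one writes, the other must read back in the same order. The following stdlib-only sketch shows one possible versioned layout (a version byte, then metadata holding the input shapes, then the parameter values) using java.io.DataOutputStream and DataInputStream. The byte layout, the MiniBlock class and the MalformedModelException stand-in are hypothetical; this is not DJL's on-disk format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Stdlib-only sketch of a versioned save/load round trip (hypothetical byte layout, not DJL's). */
class SerializationSketch {

    /** Hypothetical stand-in for ai.djl.MalformedModelException. */
    static class MalformedModelException extends Exception {
        MalformedModelException(String message) {
            super(message);
        }
    }

    static class MiniBlock {
        final byte version;
        long[][] inputShapes; // metadata
        double[] params;      // parameter values

        MiniBlock(byte version) {
            this.version = version;
        }

        void saveParameters(DataOutputStream os) throws IOException {
            os.writeByte(version); // version first, so a loader can check it before reading anything else
            saveMetadata(os);
            os.writeInt(params.length);
            for (double p : params) {
                os.writeDouble(p);
            }
        }

        void loadParameters(DataInputStream is) throws IOException, MalformedModelException {
            byte loadVersion = is.readByte();
            loadMetadata(loadVersion, is);
            double[] loaded = new double[is.readInt()];
            for (int i = 0; i < loaded.length; i++) {
                loaded[i] = is.readDouble();
            }
            params = loaded;
        }

        protected void saveMetadata(DataOutputStream os) throws IOException {
            saveInputShapes(os);
        }

        protected void loadMetadata(byte loadVersion, DataInputStream is)
                throws IOException, MalformedModelException {
            if (loadVersion != version) {
                throw new MalformedModelException(
                        "Expected version " + version + " but found " + loadVersion);
            }
            readInputShapes(is);
        }

        /** Layout: number of shapes, then for each shape its rank followed by its dimensions. */
        protected void saveInputShapes(DataOutputStream os) throws IOException {
            os.writeInt(inputShapes.length);
            for (long[] shape : inputShapes) {
                os.writeInt(shape.length);
                for (long dim : shape) {
                    os.writeLong(dim);
                }
            }
        }

        protected void readInputShapes(DataInputStream is) throws IOException {
            long[][] shapes = new long[is.readInt()][];
            for (int i = 0; i < shapes.length; i++) {
                shapes[i] = new long[is.readInt()];
                for (int j = 0; j < shapes[i].length; j++) {
                    shapes[i][j] = is.readLong();
                }
            }
            inputShapes = shapes;
        }
    }

    /** Saves source to an in-memory buffer and loads the bytes into target. */
    static void roundTrip(MiniBlock source, MiniBlock target)
            throws IOException, MalformedModelException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            source.saveParameters(os);
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            target.loadParameters(is);
        }
    }
}
```

Loading into a block constructed with a different version fails in loadMetadata before any parameter values are read, which is the reason for writing the version first.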
-
saveMetadata
Override this method to save additional data apart from parameter values.
This default implementation saves the currently set input shapes.
- Parameters:
os - the non-null output stream the parameter values and metadata are written to
- Throws:
IOException - if saving fails
-
loadMetadata
protected void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException
Override this to load additional metadata with the parameter values.
If you override saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. This default implementation checks whether the version number matches and, if it does not, throws a MalformedModelException. It then restores the input shapes.
- Parameters:
loadVersion - the version used for loading this metadata
is - the input stream we are loading from
- Throws:
IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format
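When a block's metadata format changes, the version passed to loadMetadata lets an override keep reading older files. A minimal, stdlib-only sketch follows, assuming a hypothetical ScaledBlock whose version 2 added a scale field, and a stand-in MalformedModelException; it is not DJL code. A real override would likely also need to restore the input shapes, as the default implementation does.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Stdlib-only sketch of a backward-compatible loadMetadata override (hypothetical, not DJL). */
class CompatSketch {

    /** Hypothetical stand-in for ai.djl.MalformedModelException. */
    static class MalformedModelException extends Exception {
        MalformedModelException(String message) {
            super(message);
        }
    }

    /** Hypothetical block at format version 2; version 2 added a 'scale' field to the metadata. */
    static class ScaledBlock {
        static final byte VERSION = 2;
        double scale = 1.0;

        /** Writes the current (version 2) metadata format. */
        void saveMetadata(DataOutputStream os) throws IOException {
            os.writeDouble(scale);
        }

        /** Accepts version 2 and the older version 1, which had no 'scale' field. */
        void loadMetadata(byte loadVersion, DataInputStream is)
                throws IOException, MalformedModelException {
            if (loadVersion == VERSION) {
                scale = is.readDouble();
            } else if (loadVersion == 1) {
                scale = 1.0; // field did not exist yet; fall back to a default
            } else {
                throw new MalformedModelException("Unsupported version " + loadVersion);
            }
        }
    }

    /** Simulates loading metadata bytes that were written by a block of the given version. */
    static ScaledBlock load(byte version, byte[] metadata)
            throws IOException, MalformedModelException {
        ScaledBlock block = new ScaledBlock();
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(metadata))) {
            block.loadMetadata(version, is);
        }
        return block;
    }

    /** Produces version 2 metadata bytes for the given scale. */
    static byte[] saveV2(double scale) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            ScaledBlock block = new ScaledBlock();
            block.scale = scale;
            block.saveMetadata(os);
        }
        return bytes.toByteArray();
    }
}
```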
-
saveInputShapes
- Throws:
IOException
-
readInputShapes
- Throws:
IOException
-
toString
-
getInputShapes
Returns the input shapes of the block. The input shapes are only available after the block is initialized; otherwise, an IllegalStateException is thrown.
- Specified by:
getInputShapes in interface Block
- Returns:
the input shapes of the block
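The guard described for getInputShapes can be sketched in a few lines of plain Java. MiniBlock and its long[][] shapes are hypothetical stand-ins for a block and its Shape[]; the only point shown is that the getter fails fast with an IllegalStateException instead of returning null before initialization.

```java
/** Stdlib-only sketch of a getter guarded by initialization state (hypothetical, not DJL). */
class InputShapesSketch {

    static class MiniBlock {
        private long[][] inputShapes; // null until initialize() runs

        /** Records the input shapes (a shallow copy of the varargs array). */
        void initialize(long[]... shapes) {
            inputShapes = shapes.clone();
        }

        /** Fails fast instead of returning null before initialization. */
        long[][] getInputShapes() {
            if (inputShapes == null) {
                throw new IllegalStateException(
                        "getInputShapes() can only be called after the block is initialized");
            }
            return inputShapes;
        }
    }
}
```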
-
getOutputDataTypes
Returns the output data types of the block.
- Specified by:
getOutputDataTypes in interface Block
- Returns:
the output data types of the block
-