public abstract class AbstractBlock extends java.lang.Object implements Block
AbstractBlock is an abstract implementation of Block.

It is recommended that all Block classes that have children extend AbstractBlock.
To create your own blocks, you need to do the following:

- Use addParameter(Parameter, Shape) or addParameter(Parameter, Function) to add parameters to your block in the constructor, if necessary.
- Use addChildBlock(String, Block) to add child blocks, if necessary.
- Implement Block.getOutputShapes(NDManager, Shape[]) to determine the shape of your custom block's output based on the input it will receive.
- If you added child blocks, override initializeChildBlocks(NDManager, DataType, Shape...) to initialize them based on the input shape your block will receive. You can skip this if your block does not contain child blocks.
- Implement Block.forward(ParameterStore, NDList, boolean, PairList) to implement the computation of your block.
- If your block has state other than its parameters, override saveMetadata(DataOutputStream) and loadMetadata(byte, DataInputStream). If you do not need to save or load any state other than parameters in your block, you can skip this.
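The output-shape step is usually a small pure function of the input shape. As an illustration only, the sketch below assumes a hypothetical single-input, fully connected block with `units` outputs; `DenseOutputShape` is not a DJL class, and `long[]` stands in for `Shape` so the code runs without DJL. The real `getOutputShapes` takes an `NDManager` and a `Shape[]`.

```java
import java.util.Arrays;

// Hypothetical shape rule for a single-input, fully connected block.
// long[] stands in for DJL's Shape so the sketch runs without DJL.
final class DenseOutputShape {
    private DenseOutputShape() {}

    // Keeps the leading (batch) dimensions and replaces the last dimension with `units`.
    static long[] outputShape(long[] inputShape, long units) {
        if (inputShape.length == 0) {
            throw new IllegalArgumentException("input shape must have at least one dimension");
        }
        long[] out = Arrays.copyOf(inputShape, inputShape.length);
        out[out.length - 1] = units;
        return out;
    }
}
```

For an input of shape (32, 784) and `units = 10`, this sketch returns (32, 10).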
If you use addParameter(Parameter) to add parameters, you have to take care of parameter initialization yourself. In this case, you need to override getParameterShape(String, Shape[]) to determine the shape of your parameters. If you use the other variants of addParameter, this is done for you.
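The three addParameter variants differ only in where a parameter's shape comes from: a fixed shape, a callback evaluated once the input shapes are known, or a manual getParameterShape override. The sketch below models that with hypothetical `ToyBlock` and `ToyDense` classes (not DJL types). It assumes `long[]` in place of `Shape`, identifies parameters by name, and uses a `null` callback to mark a parameter added without shape information, loosely mirroring the parameterShapeCallbacks field.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical, DJL-free model of parameter shape resolution.
// long[] stands in for Shape; String names stand in for Parameter objects.
abstract class ToyBlock {
    // Mirrors parameterShapeCallbacks: name -> callback (null = resolve manually).
    protected final Map<String, Function<long[][], long[]>> shapeCallbacks = new LinkedHashMap<>();

    // Like addParameter(Parameter, Shape): the shape is fixed up front.
    protected void addParameter(String name, long[] shape) {
        long[] copy = shape.clone();
        shapeCallbacks.put(name, inputs -> copy);
    }

    // Like addParameter(Parameter, Function): the shape depends on the input shapes.
    protected void addParameter(String name, Function<long[][], long[]> callback) {
        shapeCallbacks.put(name, callback);
    }

    // Like addParameter(Parameter): no callback, so getParameterShape must be overridden.
    protected void addParameter(String name) {
        shapeCallbacks.put(name, null);
    }

    // Default resolution uses the registered callback, if there is one.
    long[] getParameterShape(String name, long[][] inputShapes) {
        Function<long[][], long[]> callback = shapeCallbacks.get(name);
        if (callback == null) {
            throw new IllegalStateException("No shape callback for '" + name + "'; override getParameterShape");
        }
        return callback.apply(inputShapes);
    }

    // Resolves every parameter once the input shapes are known, in insertion order.
    Map<String, long[]> resolveShapes(long[][] inputShapes) {
        Map<String, long[]> shapes = new LinkedHashMap<>();
        for (String name : shapeCallbacks.keySet()) {
            shapes.put(name, getParameterShape(name, inputShapes));
        }
        return shapes;
    }
}

// Example subclass using all three variants (parameter names are illustrative only).
class ToyDense extends ToyBlock {
    ToyDense(long units) {
        addParameter("weight", inputs -> new long[] {units, inputs[0][inputs[0].length - 1]});
        addParameter("bias", new long[] {units});
        addParameter("scale");
    }

    @Override
    long[] getParameterShape(String name, long[][] inputShapes) {
        if (name.equals("scale")) {
            return new long[] {1}; // manual resolution for the callback-less parameter
        }
        return super.getParameterShape(name, inputShapes);
    }
}
```

With an input shape of (32, 784) and 10 units, the sketch resolves `weight` to (10, 784), `bias` to (10) and `scale` to (1).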
Modifier and Type | Field and Description |
---|---|
protected BlockList | children: All direct children of this Block. |
protected java.util.List<java.lang.String> | inputNames: List of names for the input; defaults to ["data"] unless manually changed. |
protected Shape[] | inputShapes: The shape of the input for this block, set by the initialization process. |
protected java.util.LinkedHashMap<java.lang.String,Parameter> | parameters: All direct parameters of this Block. |
protected java.util.LinkedHashMap<java.lang.String,java.util.function.Function<Shape[],Shape>> | parameterShapeCallbacks: Callbacks to determine the shape of a parameter. |
protected byte | version: The model version of this block, used for checking if parameters are still valid during parameter loading. |
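inputNames and inputShapes together back describeInput(), which returns a PairList of input names and shapes. The sketch below pairs them positionally; `InputDescription` is hypothetical, `Map.Entry` stands in for PairList entries and `long[]` for `Shape`. Rejecting a count mismatch is this sketch's own choice, not documented DJL behavior.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: pair each input name with the shape at the same position.
// Map.Entry stands in for DJL's PairList entries; long[] stands in for Shape.
final class InputDescription {
    private InputDescription() {}

    static List<Map.Entry<String, long[]>> describe(List<String> inputNames, long[][] inputShapes) {
        if (inputNames.size() != inputShapes.length) {
            throw new IllegalArgumentException(
                    "expected " + inputNames.size() + " shapes, got " + inputShapes.length);
        }
        List<Map.Entry<String, long[]>> pairs = new ArrayList<>();
        for (int i = 0; i < inputShapes.length; i++) {
            pairs.add(new SimpleEntry<>(inputNames.get(i), inputShapes[i]));
        }
        return pairs;
    }
}
```

With the default name list ["data"] and one input of shape (1, 28, 28), the result is a single pair ("data", (1, 28, 28)).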
Constructor and Description |
---|
AbstractBlock(byte version): Builds an empty block with the given version for parameter serialization. |
Modifier and Type | Method and Description |
---|---|
protected <B extends Block> B | addChildBlock(java.lang.String name, B block): Use this to add a child block to this block. |
protected <P extends Parameter> P | addParameter(P parameter): Adds a parameter to this block. |
protected <P extends Parameter> P | addParameter(P parameter, java.util.function.Function<Shape[],Shape> shapeCallback): Adds a parameter to this block. |
protected <P extends Parameter> P | addParameter(P parameter, Shape shape): Adds a parameter to this block. |
protected void | beforeInitialize(Shape[] inputShapes): Performs any action necessary before initialization. |
void | cast(DataType dataType): Guaranteed to throw an exception. |
void | clear(): Closes all the parameters of the block. |
ai.djl.util.PairList<java.lang.String,Shape> | describeInput(): Returns a PairList of input names and shapes. |
BlockList | getChildren(): Returns a list of all the children of the block. |
ParameterList | getDirectParameters(): Returns a list of all the direct parameters of the block. |
ParameterList | getParameters(): Returns a list of all the parameters of the block, including the parameters of its children fetched recursively. |
Shape | getParameterShape(java.lang.String name, Shape[] inputShapes): Returns the shape of the specified direct parameter of this block given the shapes of the input to the block. |
Shape[] | initialize(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the parameters of the block. |
protected void | initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes): Initializes the child blocks of this block. |
boolean | isInitialized(): Returns whether the block is initialized. |
protected void | loadMetadata(byte loadVersion, java.io.DataInputStream is): Override this to load additional metadata with the parameter values. |
void | loadParameters(NDManager manager, java.io.DataInputStream is): Loads the parameters from the given input stream. |
protected void | readInputShapes(java.io.DataInputStream is) |
protected void | saveInputShapes(java.io.DataOutputStream os) |
protected void | saveMetadata(java.io.DataOutputStream os): Override this method to save additional data apart from parameter values. |
void | saveParameters(java.io.DataOutputStream os): Writes the parameters of the block to the given output stream. |
void | setInitializer(Initializer initializer): Sets an Initializer to the block. |
void | setInitializer(Initializer initializer, java.lang.String paramName): Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set. |
java.lang.String | toString() |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Block: forward, forward, forward, getOutputShapes, validateLayout
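protected Shape[] inputShapes
The shape of the input for this block, set by the initialization process.

protected java.util.List<java.lang.String> inputNames
List of names for the input; defaults to ["data"] unless manually changed.

protected byte version
The model version of this block, used for checking if parameters are still valid during parameter loading.

protected BlockList children
All direct children of this Block. Use the addChildBlock(String, Block) method to add children. All children in this map are automatically loaded / saved.

protected java.util.LinkedHashMap<java.lang.String,Parameter> parameters
All direct parameters of this Block. Use the addParameter(Parameter) method to add parameters. All parameters in this map are automatically loaded / saved.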
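protected java.util.LinkedHashMap<java.lang.String,java.util.function.Function<Shape[],Shape>> parameterShapeCallbacks
Callbacks to determine the shape of a parameter. If a parameter has no callback (for example, because it was added with addParameter(Parameter)), extending classes need to override Block.getParameterShape(String, Shape[]) and implement parameter shape resolution manually.

public AbstractBlock(byte version)
Builds an empty block with the given version for parameter serialization.
Parameters:
version - the version to use for parameter serialization.

protected final <B extends Block> B addChildBlock(java.lang.String name, B block)
Use this to add a child block to this block.
Type Parameters:
B - the type of block
Parameters:
name - the name of the block; must be unique, otherwise existing children with this name are removed; must not be null.
block - the block; must not be null.

protected final <P extends Parameter> P addParameter(P parameter)
Adds a parameter to this block. If you use this method, you need to override Block.getParameterShape(String, Shape[]) and return the shapes of parameters yourself.
Type Parameters:
P - the specific parameter subclass
Parameters:
parameter - the parameter to add, not null

protected final <P extends Parameter> P addParameter(P parameter, Shape shape)
Adds a parameter to this block.
Type Parameters:
P - the specific parameter subclass
Parameters:
parameter - the parameter to add, not null
shape - the shape of the parameter

protected final <P extends Parameter> P addParameter(P parameter, java.util.function.Function<Shape[],Shape> shapeCallback)
Adds a parameter to this block.
Type Parameters:
P - the specific parameter subclass
Parameters:
parameter - the parameter to add, not null
shapeCallback - the method to call once the input shape of this block is known, to determine the shape of the given parameter

public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
Returns the shape of the specified direct parameter of this block given the shapes of the input to the block.
Specified by: getParameterShape in interface Block
Parameters:
name - the name of the parameter
inputShapes - the shapes of the input to the block

public BlockList getChildren()
Returns a list of all the children of the block.
Specified by: getChildren in interface Block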
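Two details of addChildBlock(String, B) are worth seeing in isolation: the generic return type hands the block back with its static type intact, so it can be stored in a typed field without a cast, and a second child with the same name replaces the first. The sketch below models this with hypothetical `ToyChild`, `ToyLinear` and `ToyParent` classes (not DJL types); it assumes a replaced child is re-added at the end of the insertion order, which the documentation does not specify.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;

// Hypothetical, DJL-free model of addChildBlock(String, B).
interface ToyChild {}

class ToyLinear implements ToyChild {
    final int units;

    ToyLinear(int units) {
        this.units = units;
    }
}

class ToyParent {
    // Insertion-ordered children, standing in for the BlockList field.
    private final Map<String, ToyChild> children = new LinkedHashMap<>();
    final ToyLinear fc;

    ToyParent() {
        // The generic return type lets the child be kept in a typed field without a cast.
        fc = addChildBlock("fc", new ToyLinear(10));
    }

    final <B extends ToyChild> B addChildBlock(String name, B block) {
        Objects.requireNonNull(name, "name must not be null");
        Objects.requireNonNull(block, "block must not be null");
        // An existing child with the same name is removed before the new one is added.
        children.remove(name);
        children.put(name, block);
        return block;
    }

    Map<String, ToyChild> children() {
        return children;
    }
}
```

Adding another child named "fc" leaves one child under that name, namely the new block.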
public ai.djl.util.PairList<java.lang.String,Shape> describeInput()
Returns a PairList of input names and shapes.
Specified by: describeInput in interface Block
Returns: a PairList of input names and shapes

public void setInitializer(Initializer initializer)
Sets an Initializer to the block.
Specified by: setInitializer in interface Block
Parameters:
initializer - the initializer to set

public void setInitializer(Initializer initializer, java.lang.String paramName)
Sets an Initializer to the specified direct parameter of the block, overriding the initializer of the parameter, if already set.
Specified by: setInitializer in interface Block
Parameters:
initializer - the initializer to be set
paramName - the name of the parameter

public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the parameters of the block.
Specified by: initialize in interface Block
Parameters:
manager - the NDManager used to initialize the parameters
dataType - the data type of the parameters
inputShapes - the shapes of the inputs to the block

protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block.
Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block

public ParameterList getParameters()
Returns a list of all the parameters of the block, including the parameters of its children fetched recursively.
Specified by: getParameters in interface Block
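How initializeChildBlocks should be overridden depends on how the children are wired. One common case, assumed here purely for illustration, is a block whose children run one after another, so each child is initialized with the output shape of the previous one. The hypothetical `SequentialInit` helper below shows only that shape bookkeeping: `long[]` stands in for `Shape`, and a `UnaryOperator` stands in for a child's output-shape logic.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.UnaryOperator;

// Hypothetical sketch of shape propagation in a sequential-style
// initializeChildBlocks override. Not DJL's implementation.
final class SequentialInit {
    private SequentialInit() {}

    // Returns the input shape each child would be initialized with,
    // followed by the final output shape of the whole chain.
    static List<long[]> propagate(long[] inputShape, List<UnaryOperator<long[]>> children) {
        List<long[]> shapes = new ArrayList<>();
        long[] current = inputShape;
        for (UnaryOperator<long[]> child : children) {
            shapes.add(current);             // shape this child is initialized with
            current = child.apply(current);  // its output feeds the next child
        }
        shapes.add(current);
        return shapes;
    }
}
```

For two children mapping (32, 784) to (32, 128) and then to (32, 10), the second child is initialized with (32, 128) and the chain's output shape is (32, 10).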
public ParameterList getDirectParameters()
Returns a list of all the direct parameters of the block.
Specified by: getDirectParameters in interface Block
Returns: the list of Parameter
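The difference between getDirectParameters() and getParameters() is the recursion into children. The hypothetical `ToyNode` class below (not a DJL type) models parameters as plain names. It lists a node's direct parameters first and then those of each child in insertion order; that ordering is a choice of this sketch, not documented behavior.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch contrasting direct parameters with recursively collected ones.
class ToyNode {
    final Map<String, ToyNode> children = new LinkedHashMap<>();
    final List<String> directParameters = new ArrayList<>();

    ToyNode(String... params) {
        directParameters.addAll(List.of(params));
    }

    // Only this node's own parameters.
    List<String> getDirectParameters() {
        return new ArrayList<>(directParameters);
    }

    // This node's parameters plus those of every descendant.
    List<String> getParameters() {
        List<String> all = new ArrayList<>(directParameters);
        for (ToyNode child : children.values()) {
            all.addAll(child.getParameters()); // recurse into each child
        }
        return all;
    }
}
```

A root holding "gamma" with a child holding "weight" and "bias" reports ["gamma"] directly and ["gamma", "weight", "bias"] recursively.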
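protected void beforeInitialize(Shape[] inputShapes)
Performs any action necessary before initialization.
Parameters:
inputShapes - the expected shapes of the input

public boolean isInitialized()
Returns whether the block is initialized.
Specified by: isInitialized in interface Block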
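beforeInitialize reads like a template-method hook: initialization calls it first, then initializes the parameters. The hypothetical `ToyInitBlock` below (not DJL's implementation) assumes exactly that ordering. Its hook records the input shapes, matching the inputShapes field description ("set by the initialization process"), and isInitialized() reports true only once every parameter holds a value. The `{0.0}` placeholder stands in for a real Initializer.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical template-method sketch of beforeInitialize / initialize / isInitialized.
class ToyInitBlock {
    protected long[][] inputShapes;
    // Parameter name -> value (null until initialized).
    protected final Map<String, double[]> parameters = new LinkedHashMap<>();

    ToyInitBlock() {
        parameters.put("weight", null);
        parameters.put("bias", null);
    }

    // Hook run before any parameter is initialized; here it just records the input shapes.
    protected void beforeInitialize(long[][] inputShapes) {
        this.inputShapes = inputShapes;
    }

    void initialize(long[]... inputShapes) {
        beforeInitialize(inputShapes);
        for (Map.Entry<String, double[]> entry : parameters.entrySet()) {
            entry.setValue(new double[] {0.0}); // placeholder initializer
        }
    }

    boolean isInitialized() {
        for (double[] value : parameters.values()) {
            if (value == null) {
                return false;
            }
        }
        return true;
    }
}
```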
public void clear()
Closes all the parameters of the block.

public void cast(DataType dataType)
Guaranteed to throw an exception.
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
Specified by: saveParameters in interface Block
Parameters:
os - the output stream to save the parameters to
Throws:
java.io.IOException - if an I/O error occurs

public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
Specified by: loadParameters in interface Block
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported

protected void saveMetadata(java.io.DataOutputStream os) throws java.io.IOException
Override this method to save additional data apart from parameter values. This default implementation saves the currently set input shapes.
Parameters:
os - the non-null output stream the parameter values and metadata are written to
Throws:
java.io.IOException - if saving fails

protected void loadMetadata(byte loadVersion, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this to load additional metadata with the parameter values. If you override saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method. This default implementation checks whether the version number fits; if it does not, it throws a MalformedModelException. After that, it restores the input shapes.
Parameters:
loadVersion - the version used for loading this metadata.
is - the input stream we are loading from
Throws:
java.io.IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format

protected void saveInputShapes(java.io.DataOutputStream os) throws java.io.IOException
Throws:
java.io.IOException
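The saveMetadata/loadMetadata contract boils down to writing a version marker with the extra state and refusing data written under a version the loader does not understand. The sketch below shows that pattern with plain `java.io` streams. `ToyMetadataBlock` and `MetadataVersionException` are hypothetical: the exception stands in for DJL's MalformedModelException, and, unlike the real API (where the version arrives as the `loadVersion` argument), this sketch reads the version byte itself.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Stand-in for MalformedModelException in this sketch only.
class MetadataVersionException extends RuntimeException {
    MetadataVersionException(String message) {
        super(message);
    }
}

// Hypothetical block with one piece of non-parameter state.
class ToyMetadataBlock {
    static final byte VERSION = 2;
    int hiddenSize; // example state that is not a parameter

    ToyMetadataBlock(int hiddenSize) {
        this.hiddenSize = hiddenSize;
    }

    byte[] save() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeByte(VERSION);   // version first, so the loader can check it
            os.writeInt(hiddenSize); // then the extra metadata
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    void load(byte[] data) {
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            byte loadVersion = is.readByte();
            if (loadVersion != VERSION) {
                throw new MetadataVersionException(
                        "Unsupported version " + loadVersion + ", expected " + VERSION);
            }
            hiddenSize = is.readInt();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

To stay backward compatible instead of rejecting older data, a loader could branch on `loadVersion` and read the older layout.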
protected void readInputShapes(java.io.DataInputStream is) throws java.io.IOException
Throws:
java.io.IOException
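saveInputShapes and readInputShapes have to agree on a byte layout, but the layout is not documented here. The sketch below assumes a simple one (shape count, then each shape's rank followed by its dimensions) to show the symmetric write/read pattern. `ShapeCodec` is hypothetical, `long[]` stands in for `Shape`, and DJL's actual encoding may differ.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical encoding: count, then for each shape its rank and dimensions.
final class ShapeCodec {
    private ShapeCodec() {}

    static void writeShapes(DataOutputStream os, long[][] shapes) throws IOException {
        os.writeInt(shapes.length);
        for (long[] shape : shapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    // Reads back exactly what writeShapes wrote, in the same order.
    static long[][] readShapes(DataInputStream is) throws IOException {
        int count = is.readInt();
        long[][] shapes = new long[count][];
        for (int i = 0; i < count; i++) {
            long[] shape = new long[is.readInt()];
            for (int j = 0; j < shape.length; j++) {
                shape[j] = is.readLong();
            }
            shapes[i] = shape;
        }
        return shapes;
    }

    // Convenience round trip through an in-memory buffer.
    static long[][] roundTrip(long[][] shapes) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(bytes)) {
                writeShapes(os, shapes);
            }
            try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return readShapes(is);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```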
public java.lang.String toString()
Overrides: toString in class java.lang.Object