public class ParallelBlock extends AbstractBlock
ParallelBlock is a Block whose children each form one parallel branch of the network; the outputs of
the branches are combined to produce a single output.
ParallelBlock has no direct parameters.
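The following sketch makes the data flow concrete without depending on DJL. It is a minimal illustration under stated assumptions. A plain List<Double> stands in for NDList, a java.util.function.Function stands in for each child Block, and ParallelSketch is a hypothetical name. As we read the class description, every branch receives the same input, and the combining function reduces the list of branch outputs to one output. In this sketch the combining function concatenates the outputs.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical stand-in for ParallelBlock: List<Double> plays the role of NDList,
// and each branch is a plain Function rather than a DJL Block.
final class ParallelSketch {
    private final Function<List<List<Double>>, List<Double>> combiner;
    private final List<Function<List<Double>, List<Double>>> branches = new ArrayList<>();

    ParallelSketch(Function<List<List<Double>>, List<Double>> combiner) {
        this.combiner = combiner;
    }

    // Adds one parallel branch and returns this object so calls can be chained.
    ParallelSketch add(Function<List<Double>, List<Double>> branch) {
        branches.add(branch);
        return this;
    }

    // Feeds the same input to every branch, then combines the branch outputs.
    List<Double> forward(List<Double> input) {
        List<List<Double>> outputs = new ArrayList<>();
        for (Function<List<Double>, List<Double>> branch : branches) {
            outputs.add(branch.apply(input));
        }
        return combiner.apply(outputs);
    }

    // Example combiner: concatenates the branch outputs in branch order.
    static List<Double> concat(List<List<Double>> outputs) {
        List<Double> result = new ArrayList<>();
        for (List<Double> out : outputs) {
            result.addAll(out);
        }
        return result;
    }

    public static void main(String[] args) {
        ParallelSketch block = new ParallelSketch(ParallelSketch::concat)
                .add(x -> x)                       // identity branch
                .add(x -> List.of(x.get(0) * 10)); // branch that scales the first element
        System.out.println(block.forward(List.of(1.0, 2.0))); // [1.0, 2.0, 10.0]
    }
}
```

The combining function is the only point where branch outputs interact. You could swap concat for any other reduction without changing the branches.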
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version

| Constructor and Description |
|---|
| ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function) - Creates a parallel block whose branches are combined to form a single output by the given function. |
| ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function, java.util.List<Block> blocks) - Creates a parallel block whose branches are formed by each block in the list of blocks, and are combined to form a single output by the given function. |
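The second constructor in the table takes the branches up front instead of through add calls. The sketch below illustrates that form without DJL. SumParallel is a hypothetical name, List<Double> stands in for NDList, and Function stands in for Block. Its combining function sums the branch outputs element by element, which is the kind of merge a residual connection needs. The sketch assumes every branch returns an output of the same length and rejects mismatches.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical stand-in for the two-argument constructor: the branches are
// supplied as a list up front instead of being added one by one.
final class SumParallel {
    private final Function<List<List<Double>>, List<Double>> combiner;
    private final List<Function<List<Double>, List<Double>>> branches;

    SumParallel(Function<List<List<Double>>, List<Double>> combiner,
                List<Function<List<Double>, List<Double>>> branches) {
        this.combiner = combiner;
        this.branches = new ArrayList<>(branches); // defensive copy
    }

    // Runs every branch on the same input, then combines the outputs.
    List<Double> forward(List<Double> input) {
        List<List<Double>> outputs = new ArrayList<>();
        for (Function<List<Double>, List<Double>> branch : branches) {
            outputs.add(branch.apply(input));
        }
        return combiner.apply(outputs);
    }

    // Element-wise sum of all branch outputs; assumes at least one branch and equal lengths.
    static List<Double> sum(List<List<Double>> outputs) {
        List<Double> result = new ArrayList<>(outputs.get(0));
        for (int b = 1; b < outputs.size(); b++) {
            List<Double> out = outputs.get(b);
            if (out.size() != result.size()) {
                throw new IllegalArgumentException("branch outputs differ in length");
            }
            for (int i = 0; i < result.size(); i++) {
                result.set(i, result.get(i) + out.get(i));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        Function<List<Double>, List<Double>> identity = x -> x;
        Function<List<Double>, List<Double>> doubled = x -> {
            List<Double> y = new ArrayList<>();
            for (double v : x) {
                y.add(2 * v);
            }
            return y;
        };
        SumParallel block = new SumParallel(SumParallel::sum, List.of(identity, doubled));
        System.out.println(block.forward(List.of(1.0, 2.0, 3.0))); // [3.0, 6.0, 9.0]
    }
}
```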
| Modifier and Type | Method and Description |
|---|---|
| ParallelBlock | add(Block block) - Adds the given Block as one parallel branch of this block. |
| ParallelBlock | add(java.util.function.Function<NDList,NDList> f) - Adds a LambdaBlock that applies the given function as a new parallel branch. |
| ParallelBlock | addAll(Block... blocks) - Adds an array of blocks, each of which is a parallel branch. |
| ParallelBlock | addAll(java.util.Collection<Block> blocks) - Adds a Collection of blocks, each of which is a parallel branch. |
| NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) - Applies the operating function of the block once. |
| Shape[] | getOutputShapes(NDManager manager, Shape[] inputShapes) - Returns the expected output shapes of the block for the specified input shapes. |
| void | initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) - Initializes the child blocks of this block. |
| void | loadMetadata(byte version, java.io.DataInputStream is) - Override this to load additional metadata with the parameter values. |
| java.lang.String | toString() |
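Each add and addAll method returns a ParallelBlock. We assume that is the block itself, so calls can be chained. The sketch below illustrates the four overloads without DJL. BranchList, Branch and LambdaBranch are hypothetical stand-ins for ParallelBlock, Block and LambdaBlock, and List<Double> stands in for NDList. The sketch only collects branches and leaves out the forward pass.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.Function;

// Hypothetical stand-in for ParallelBlock's add/addAll overloads. Every method
// returns this, so calls can be chained.
final class BranchList {
    private final List<Branch> branches = new ArrayList<>();

    // Adds one branch (mirrors add(Block)).
    BranchList add(Branch branch) {
        branches.add(branch);
        return this;
    }

    // Wraps a function as a branch (mirrors add(Function), which builds a LambdaBlock).
    BranchList add(Function<List<Double>, List<Double>> f) {
        return add(new LambdaBranch(f));
    }

    // Adds several branches at once (mirrors addAll(Block...)).
    BranchList addAll(Branch... more) {
        return addAll(Arrays.asList(more));
    }

    // Adds a collection of branches (mirrors addAll(Collection<Block>)).
    BranchList addAll(Collection<Branch> more) {
        branches.addAll(more);
        return this;
    }

    int size() {
        return branches.size();
    }

    public static void main(String[] args) {
        Branch identity = new LambdaBranch(x -> x);
        BranchList list = new BranchList()
                .add(identity)
                .add(x -> List.of(x.get(0) + 1))
                .addAll(identity, identity)
                .addAll(List.of(identity));
        System.out.println(list.size()); // 5
    }
}

// Stand-in for Block. An abstract class (not an interface) keeps add(x -> ...)
// unambiguous: a lambda can only target the Function overload.
abstract class Branch {
    abstract List<Double> forward(List<Double> input);
}

// Stand-in for LambdaBlock: a branch backed by a function.
final class LambdaBranch extends Branch {
    private final Function<List<Double>, List<Double>> f;

    LambdaBranch(Function<List<Double>, List<Double>> f) {
        this.f = f;
    }

    @Override
    List<Double> forward(List<Double> input) {
        return f.apply(input);
    }
}
```

Branch is an abstract class on purpose. If it were a functional interface, add(x -> x) would match both the Branch overload and the Function overload, and the call would fail to compile as ambiguous.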
Methods inherited from class AbstractBlock: addChildBlock, addParameter, addParameter, addParameter, beforeInitialize, cast, clear, describeInput, getChildren, getDirectParameters, getParameters, getParameterShape, initialize, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer

Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface Block: forward, forward, validateLayout

public ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function)
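Creates a parallel block whose branches are combined to form a single output by the given function.

Parameters:
function - the function to define how the parallel branches are combined to form a single output

public ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function, java.util.List<Block> blocks)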
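Creates a parallel block whose branches are formed by each block in the list of blocks, and are combined to form a single output by the given function.

Parameters:
function - the function to define how the parallel branches are combined
blocks - the blocks that form each of the parallel branches

public final ParallelBlock addAll(Block... blocks)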
Adds an array of blocks, each of which is a parallel branch.

Parameters:
blocks - the array of blocks to add

public final ParallelBlock addAll(java.util.Collection<Block> blocks)
Adds a Collection of blocks, each of which is a parallel branch.

Parameters:
blocks - the Collection of blocks to add

public final ParallelBlock add(Block block)
Adds the given Block as one parallel branch of this block.

Parameters:
block - the block to be added as a parallel branch

public final ParallelBlock add(java.util.function.Function<NDList,NDList> f)
Adds a LambdaBlock that applies the given function as a new parallel branch.

Parameters:
f - the function that forms the LambdaBlock

public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
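Applies the operating function of the block once. For a ParallelBlock, each branch is applied to the same inputs, and the function supplied at construction combines the list of branch outputs into a single NDList.

Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters

public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)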
Initializes the child blocks of this block.

Overrides:
initializeChildBlocks in class AbstractBlock

Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block

public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.

Parameters:
manager - an NDManager
inputShapes - the shapes of the inputs

public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward
compatibility with older binary formats, you probably need to override this method. The
default implementation checks whether the version number matches and throws a MalformedModelException if it does not; it then restores the input shapes.
Overrides:
loadMetadata in class AbstractBlock

Parameters:
version - the version used for loading this metadata
is - the input stream to load from

Throws:
java.io.IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format

public java.lang.String toString()

Overrides:
toString in class AbstractBlock
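The loadMetadata description comes down to a version check followed by restoring the saved input shapes. The following sketch shows that pattern using only java.io, and it is not DJL's code. Several details are assumptions made for illustration:
- the byte layout: an int shape count, then for each shape an int rank followed by that many long dimensions;
- the value of SUPPORTED_VERSION;
- the nested MalformedModelException class, which only stands in for DJL's exception.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

final class MetadataSketch {
    // Hypothetical version number this sketch knows how to read.
    static final byte SUPPORTED_VERSION = 2;

    // Local stand-in for DJL's MalformedModelException.
    static final class MalformedModelException extends Exception {
        MalformedModelException(String message) {
            super(message);
        }
    }

    // Checks the version, then restores the input shapes from the stream.
    // Assumed layout: int shapeCount, then per shape: int rank, rank x long dims.
    static long[][] loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedModelException {
        if (version != SUPPORTED_VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        int shapeCount = is.readInt();
        long[][] shapes = new long[shapeCount][];
        for (int s = 0; s < shapeCount; s++) {
            int rank = is.readInt();
            shapes[s] = new long[rank];
            for (int d = 0; d < rank; d++) {
                shapes[s][d] = is.readLong();
            }
        }
        return shapes;
    }

    // Writes shapes in the same assumed layout, so the two methods round-trip.
    static byte[] encode(long[][] shapes) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            os.writeInt(shapes.length);
            for (long[] shape : shapes) {
                os.writeInt(shape.length);
                for (long dim : shape) {
                    os.writeLong(dim);
                }
            }
        }
        return bytes.toByteArray();
    }

    public static void main(String[] args) throws Exception {
        byte[] data = encode(new long[][] {{1, 28, 28}});
        long[][] shapes = loadMetadata(SUPPORTED_VERSION,
                new DataInputStream(new ByteArrayInputStream(data)));
        System.out.println(Arrays.deepToString(shapes)); // [[1, 28, 28]]
    }
}
```

A real override must read exactly what the matching saveMetadata writes. Keeping the writer and the reader side by side, as encode and loadMetadata are here, is one way to keep the two formats in sync.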