public class ParallelBlock extends AbstractBlock
ParallelBlock is a Block whose children each form a parallel branch in the network; the outputs of these branches are combined to produce a single output.
ParallelBlock has no direct parameters.
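The combining behaviour described above can be illustrated with a small toy model. This is a sketch, not the DJL API. Plain double[] vectors stand in for NDList, and a Function<double[], double[]> stands in for each child Block. The names ToyParallelBlock, concat and ParallelSketch are hypothetical. The model assumes that every branch receives the same input; its outputs are gathered into a list and handed to the combining function.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

public class ParallelSketch {

    /** Hypothetical stand-in for ParallelBlock; each branch is a plain function. */
    static final class ToyParallelBlock {
        private final Function<List<double[]>, double[]> combiner;
        private final List<Function<double[], double[]>> branches = new ArrayList<>();

        ToyParallelBlock(Function<List<double[]>, double[]> combiner) {
            this.combiner = combiner;
        }

        ToyParallelBlock add(Function<double[], double[]> branch) {
            branches.add(branch);
            return this;
        }

        /** Feeds the same input to every branch, then combines the list of branch outputs. */
        double[] forward(double[] input) {
            List<double[]> outputs = new ArrayList<>();
            for (Function<double[], double[]> branch : branches) {
                outputs.add(branch.apply(input));
            }
            return combiner.apply(outputs);
        }
    }

    /** Combining function: concatenates the branch outputs into one vector. */
    static double[] concat(List<double[]> parts) {
        return parts.stream().flatMapToDouble(Arrays::stream).toArray();
    }

    public static void main(String[] args) {
        ToyParallelBlock block = new ToyParallelBlock(ParallelSketch::concat)
                .add(x -> new double[] {x[0] * 2})      // branch 1: doubles the first element
                .add(x -> new double[] {x[0] + x[1]});  // branch 2: sums the two elements
        System.out.println(Arrays.toString(block.forward(new double[] {3.0, 4.0}))); // [6.0, 7.0]
    }
}
```

Concatenation is only one possible combining function. Any function from a list of branch outputs to a single output fits the same slot.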
Fields inherited from class ai.djl.nn.AbstractBlock: children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version
| Constructor | Description |
|---|---|
| ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function) | Creates a parallel block whose branches are combined to form a single output by the given function. |
| ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function, java.util.List<Block> blocks) | Creates a parallel block whose branches are formed by each block in the list of blocks, and are combined to form a single output by the given function. |
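The two constructors differ only in whether the branches are supplied up front or added later. The toy model below sketches this under stated assumptions. It is not the DJL API: ToyParallelBlock, elementwiseSum and ParallelConstructors are hypothetical names, and plain functions over double[] stand in for Block and NDList. The sketch assumes that building from a list should behave the same as starting empty and adding the same branches in order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

public class ParallelConstructors {

    /** Hypothetical stand-in for ParallelBlock, reduced to its two constructors and add. */
    static final class ToyParallelBlock {
        private final Function<List<double[]>, double[]> combiner;
        private final List<Function<double[], double[]>> branches = new ArrayList<>();

        // Like ParallelBlock(function): the block starts with no branches.
        ToyParallelBlock(Function<List<double[]>, double[]> combiner) {
            this.combiner = combiner;
        }

        // Like ParallelBlock(function, blocks): one branch per element of the list.
        ToyParallelBlock(Function<List<double[]>, double[]> combiner,
                         List<Function<double[], double[]>> blocks) {
            this(combiner);
            branches.addAll(blocks);
        }

        ToyParallelBlock add(Function<double[], double[]> branch) {
            branches.add(branch);
            return this;
        }

        double[] forward(double[] input) {
            List<double[]> outputs = new ArrayList<>();
            for (Function<double[], double[]> branch : branches) {
                outputs.add(branch.apply(input));
            }
            return combiner.apply(outputs);
        }
    }

    /** Adds the branch outputs element-wise; assumes at least one branch and equal lengths. */
    static double[] elementwiseSum(List<double[]> parts) {
        double[] out = new double[parts.get(0).length];
        for (double[] p : parts) {
            for (int i = 0; i < out.length; i++) {
                out[i] += p[i];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        Function<double[], double[]> identity = x -> x.clone();
        Function<double[], double[]> doubled = x -> Arrays.stream(x).map(v -> 2 * v).toArray();

        ToyParallelBlock viaList =
                new ToyParallelBlock(ParallelConstructors::elementwiseSum, List.of(identity, doubled));
        ToyParallelBlock viaAdd =
                new ToyParallelBlock(ParallelConstructors::elementwiseSum).add(identity).add(doubled);

        double[] input = {1.0, 2.0};
        System.out.println(Arrays.toString(viaList.forward(input))); // [3.0, 6.0]
        System.out.println(Arrays.toString(viaAdd.forward(input)));  // [3.0, 6.0]
    }
}
```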
| Modifier and Type | Method | Description |
|---|---|---|
| ParallelBlock | add(Block block) | Adds the given Block to this block as one parallel branch. |
| ParallelBlock | add(java.util.function.Function<NDList,NDList> f) | Adds a LambdaBlock that applies the given function as one of the parallel branches. |
| ParallelBlock | addAll(Block... blocks) | Adds an array of blocks, each of which is a parallel branch. |
| ParallelBlock | addAll(java.util.Collection<Block> blocks) | Adds a Collection of blocks, each of which is a parallel branch. |
| protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | |
| Shape[] | getOutputShapes(NDManager manager, Shape[] inputShapes) | Returns the expected output shapes of the block for the specified input shapes. |
| void | initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) | Initializes the child blocks of this block. |
| void | loadMetadata(byte version, java.io.DataInputStream is) | Override this to load additional metadata with the parameter values. |
| java.lang.String | toString() | |
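getOutputShapes reports the block's output shapes for given input shapes. The shape of the combined output depends on the combining function. One way to infer it without writing a separate shape rule for that function is to run the function on zero-filled placeholders that have each branch's output shape. The toy below does this for 1-D shapes only. The branch shape functions, inferOutputLength, concat and ParallelShapes are hypothetical, and the sketch is not claimed to match DJL's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.function.IntUnaryOperator;

public class ParallelShapes {

    /** Combining function: concatenates the branch outputs into one vector. */
    static double[] concat(List<double[]> parts) {
        return parts.stream().flatMapToDouble(Arrays::stream).toArray();
    }

    /**
     * Infers the combined output length for a 1-D input of the given length. Each
     * hypothetical branch maps the input length to its own output length. Zero-filled
     * placeholders of those lengths are passed through the combiner, and the length of
     * the combined result is returned.
     */
    static int inferOutputLength(int inputLength,
                                 List<IntUnaryOperator> branchShapes,
                                 Function<List<double[]>, double[]> combiner) {
        if (branchShapes.isEmpty()) {
            throw new IllegalArgumentException("no branches to combine");
        }
        List<double[]> placeholders = new ArrayList<>();
        for (IntUnaryOperator branchShape : branchShapes) {
            placeholders.add(new double[branchShape.applyAsInt(inputLength)]);
        }
        return combiner.apply(placeholders).length;
    }

    public static void main(String[] args) {
        List<IntUnaryOperator> branchShapes = List.of(
                n -> n,      // branch that keeps the length
                n -> n / 2,  // branch that halves it
                n -> 3);     // branch with a fixed-size output
        System.out.println(inferOutputLength(8, branchShapes, ParallelShapes::concat)); // 15
    }
}
```

This approach reuses the combining function itself. As a result, the inferred shape cannot drift out of sync with what the forward pass actually produces.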
Methods inherited from class ai.djl.nn.AbstractBlock: addChildBlock, addParameter, addParameter, addParameter, beforeInitialize, cast, clear, describeInput, forward, getChildren, getDirectParameters, getParameters, getParameterShape, initialize, isInitialized, loadParameters, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block: forward, forward, validateLayout

public ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function)
Creates a parallel block whose branches are combined to form a single output by the given function.
Parameters:
function - the function that defines how the parallel branches are combined to form a single output

public ParallelBlock(java.util.function.Function<java.util.List<NDList>,NDList> function, java.util.List<Block> blocks)
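Creates a parallel block whose branches are formed by each block in the list of blocks, and are combined to form a single output by the given function.
Parameters:
function - the function that defines how the parallel branches are combined
blocks - the blocks that form each of the parallel branches

public final ParallelBlock addAll(Block... blocks)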
Adds an array of blocks, each of which is a parallel branch.
Parameters:
blocks - the array of blocks to add

public final ParallelBlock addAll(java.util.Collection<Block> blocks)
Adds a Collection of blocks, each of which is a parallel branch.
Parameters:
blocks - the Collection of blocks to add

public final ParallelBlock add(Block block)
Adds the given Block to this block as one parallel branch.
Parameters:
block - the block to be added as a parallel branch

public final ParallelBlock add(java.util.function.Function<NDList,NDList> f)
Adds a LambdaBlock that applies the given function as one of the parallel branches.
Parameters:
f - the function that forms the LambdaBlock
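The add and addAll overloads above all return the block itself, so calls can be chained. add(Function) wraps a plain function in a block before adding it as a branch. The toy model below sketches that pattern. It is not the DJL API: ToyBlock, ToyLambdaBlock, Scale, ToyParallelBlock and ParallelAddSketch are hypothetical. ToyBlock is deliberately an abstract class rather than an interface. This keeps a lambda argument from matching both add overloads, which would make the call ambiguous.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.function.Function;

public class ParallelAddSketch {

    /** Hypothetical stand-in for Block; an abstract class so lambdas cannot target it. */
    abstract static class ToyBlock {
        abstract double[] forward(double[] input);
    }

    /** Hypothetical stand-in for LambdaBlock: wraps a plain function as a block. */
    static final class ToyLambdaBlock extends ToyBlock {
        private final Function<double[], double[]> fn;

        ToyLambdaBlock(Function<double[], double[]> fn) {
            this.fn = fn;
        }

        @Override
        double[] forward(double[] input) {
            return fn.apply(input);
        }
    }

    /** Hypothetical block that multiplies every element by a constant factor. */
    static final class Scale extends ToyBlock {
        private final double factor;

        Scale(double factor) {
            this.factor = factor;
        }

        @Override
        double[] forward(double[] input) {
            return Arrays.stream(input).map(v -> v * factor).toArray();
        }
    }

    /** Hypothetical stand-in for ParallelBlock, showing the add/addAll overloads. */
    static final class ToyParallelBlock {
        private final Function<List<double[]>, double[]> combiner;
        private final List<ToyBlock> branches = new ArrayList<>();

        ToyParallelBlock(Function<List<double[]>, double[]> combiner) {
            this.combiner = combiner;
        }

        ToyParallelBlock add(ToyBlock block) {
            branches.add(block);
            return this; // returning this lets calls be chained
        }

        ToyParallelBlock add(Function<double[], double[]> f) {
            return add(new ToyLambdaBlock(f)); // the function becomes one branch
        }

        ToyParallelBlock addAll(ToyBlock... blocks) {
            return addAll(Arrays.asList(blocks));
        }

        ToyParallelBlock addAll(Collection<ToyBlock> blocks) {
            branches.addAll(blocks);
            return this;
        }

        int branchCount() {
            return branches.size();
        }

        double[] forward(double[] input) {
            List<double[]> outputs = new ArrayList<>();
            for (ToyBlock b : branches) {
                outputs.add(b.forward(input)); // every branch sees the same input
            }
            return combiner.apply(outputs);
        }
    }

    /** Concatenates the branch outputs into one vector. */
    static double[] concat(List<double[]> parts) {
        return parts.stream().flatMapToDouble(Arrays::stream).toArray();
    }

    public static void main(String[] args) {
        ToyParallelBlock block = new ToyParallelBlock(ParallelAddSketch::concat)
                .add(x -> new double[] {x[0] + 1})     // wrapped in a ToyLambdaBlock
                .addAll(new Scale(10), new Scale(-1)); // two more branches at once
        System.out.println(block.branchCount()); // 3
        System.out.println(Arrays.toString(block.forward(new double[] {2.0}))); // [3.0, 20.0, -2.0]
    }
}
```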

protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Specified by:
forwardInternal in class AbstractBlock

public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block.
Overrides:
initializeChildBlocks in class AbstractBlock
Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block

public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
manager - an NDManager
inputShapes - the shapes of the inputs

public void loadMetadata(byte version, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Override this method to load additional metadata with the parameter values. If you override AbstractBlock.saveMetadata(DataOutputStream) or need to provide backward compatibility with older binary formats, you probably need to override this method as well. The default implementation checks whether the version number matches and throws a MalformedModelException if it does not. It then restores the input shapes.
Overrides:
loadMetadata in class AbstractBlock
Parameters:
version - the version used for loading this metadata
is - the input stream to load from
Throws:
java.io.IOException - if loading fails
MalformedModelException - if the data can be loaded but has the wrong format

public java.lang.String toString()
Overrides:
toString in class AbstractBlock
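The loadMetadata description outlines two steps: check the version, then restore the input shapes. Both can be sketched using only java.io streams. The following are all hypothetical: the byte layout (a shape count, then a rank and the dimensions for each shape), the VERSION constant, and the names writeShapes, MalformedMetadataException and MetadataSketch. DJL's actual metadata format and its MalformedModelException are not reproduced here.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

public class MetadataSketch {

    /** Hypothetical stand-in for MalformedModelException. */
    static final class MalformedMetadataException extends Exception {
        MalformedMetadataException(String message) {
            super(message);
        }
    }

    static final byte VERSION = 1; // hypothetical current format version

    /** Writes shapes in a hypothetical layout: count, then (rank, dims...) per shape. */
    static void writeShapes(long[][] shapes, DataOutputStream os) throws IOException {
        os.writeInt(shapes.length);
        for (long[] shape : shapes) {
            os.writeInt(shape.length);
            for (long dim : shape) {
                os.writeLong(dim);
            }
        }
    }

    /** Rejects an unexpected version first, then restores the shapes written by writeShapes. */
    static long[][] loadMetadata(byte version, DataInputStream is)
            throws IOException, MalformedMetadataException {
        if (version != VERSION) {
            throw new MalformedMetadataException("Unsupported metadata version: " + version);
        }
        long[][] shapes = new long[is.readInt()][];
        for (int i = 0; i < shapes.length; i++) {
            long[] shape = new long[is.readInt()];
            for (int j = 0; j < shape.length; j++) {
                shape[j] = is.readLong();
            }
            shapes[i] = shape;
        }
        return shapes;
    }

    public static void main(String[] args) throws Exception {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            writeShapes(new long[][] {{1, 28, 28}, {1, 10}}, os);
        }
        byte[] data = bytes.toByteArray();

        long[][] restored = loadMetadata(VERSION, new DataInputStream(new ByteArrayInputStream(data)));
        System.out.println(Arrays.deepToString(restored)); // [[1, 28, 28], [1, 10]]

        try {
            loadMetadata((byte) 2, new DataInputStream(new ByteArrayInputStream(data)));
        } catch (MalformedMetadataException e) {
            System.out.println(e.getMessage()); // Unsupported metadata version: 2
        }
    }
}
```

Checking the version before reading anything else means an incompatible stream is rejected with a clear error instead of being misparsed.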