public class LambdaBlock extends ParameterBlock
LambdaBlock
is a Block
with no parameters or children.
LambdaBlock
allows converting any function that takes an NDList
as input and
returns an NDList
into a parameter-less and child-less Block
. This can be useful
in converting activation functions, identity blocks, and more. For example, identity block can be
created as new LambdaBlock(x -> x)
.
inputNames, inputShapes
Constructor and Description |
---|
LambdaBlock(java.util.function.Function<NDList,NDList> lambda)
Creates a LambdaBlock that can apply the specified function.
|
Modifier and Type | Method and Description |
---|---|
NDList |
forward(ParameterStore parameterStore,
NDList inputs,
boolean training,
ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once.
|
java.util.List<Parameter> |
getDirectParameters()
Returns a list of all the direct parameters of the block.
|
Shape[] |
getOutputShapes(NDManager manager,
Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
|
Shape |
getParameterShape(java.lang.String name,
Shape[] inputShapes)
Returns the shape of the specified direct parameter of this block given the shapes of the
input to the block.
|
void |
loadParameters(NDManager manager,
java.io.DataInputStream is)
Loads the parameters from the given input stream.
|
void |
saveParameters(java.io.DataOutputStream os)
Writes the parameters of the block to the given outputStream.
|
java.lang.String |
toString() |
getChildren, initialize
beforeInitialize, cast, clear, describeInput, getParameters, isInitialized, readInputShapes, saveInputShapes, setInitializer, setInitializer
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
forward, validateLayout
public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
parameterStore
- the parameter storeinputs
- the input NDListtraining
- true for a training forward passparams
- optional parameterspublic Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)
manager
- an NDManagerinputShapes
- the shapes of the inputspublic java.util.List<Parameter> getDirectParameters()
Parameter
public Shape getParameterShape(java.lang.String name, Shape[] inputShapes)
name
- the name of the parameterinputShapes
- the shapes of the input to the blockpublic void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
os
- the outputstream to save the parameters tojava.io.IOException
- if an I/O error occurspublic void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
manager
- an NDManager to create the parameter arraysis
- the inputstream that stream the parameter valuesjava.io.IOException
- if an I/O error occursMalformedModelException
- if the model file is corrupted or unsupportedpublic java.lang.String toString()
toString
in class ParameterBlock