Package ai.djl.nn
Class LambdaBlock
- java.lang.Object
  - ai.djl.nn.AbstractBaseBlock
    - ai.djl.nn.AbstractBlock
      - ai.djl.nn.LambdaBlock
- All Implemented Interfaces:
Block
public class LambdaBlock extends AbstractBlock
LambdaBlock is a Block with no parameters or children. LambdaBlock allows converting any function that takes an NDList as input and returns an NDList into a parameter-less and child-less Block. This is useful for wrapping activation functions, identity blocks, and similar stateless operations. For example, an identity block can be created as new LambdaBlock(x -> x).
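The idea can be illustrated without DJL. The following is a minimal, stdlib-only sketch, not DJL's implementation: the hypothetical class ToyLambdaBlock stands in for LambdaBlock, a float[] stands in for an NDArray, and a List<float[]> stands in for an NDList. It shows that such a block is little more than a stored function, which is why an identity block and an element-wise activation such as ReLU both fit the pattern.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Hypothetical stand-in for LambdaBlock: float[] models an NDArray, List<float[]> models an NDList. */
final class ToyLambdaBlock {
    private final Function<List<float[]>, List<float[]>> lambda;

    ToyLambdaBlock(Function<List<float[]>, List<float[]>> lambda) {
        this.lambda = lambda;
    }

    /** The whole forward pass: no parameters to look up and no children to call. */
    List<float[]> forward(List<float[]> inputs) {
        return lambda.apply(inputs);
    }

    /** Example activation: element-wise ReLU applied to every array in the list. */
    static List<float[]> relu(List<float[]> inputs) {
        List<float[]> out = new ArrayList<>(inputs.size());
        for (float[] x : inputs) {
            float[] y = new float[x.length];
            for (int i = 0; i < x.length; i++) {
                y[i] = Math.max(0f, x[i]);
            }
            out.add(y);
        }
        return out;
    }

    public static void main(String[] args) {
        ToyLambdaBlock identity = new ToyLambdaBlock(x -> x);
        ToyLambdaBlock activation = new ToyLambdaBlock(ToyLambdaBlock::relu);
        List<float[]> in = List.of(new float[] {-1f, 2f});
        System.out.println(Arrays.toString(identity.forward(in).get(0)));   // [-1.0, 2.0]
        System.out.println(Arrays.toString(activation.forward(in).get(0))); // [0.0, 2.0]
    }
}
```

With the real DJL types, the identity case is exactly the new LambdaBlock(x -> x) shown above.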
Field Summary
Fields
Modifier and Type | Field | Description
static java.lang.String | DEFAULT_NAME |
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
Constructors
Constructor | Description
LambdaBlock(java.util.function.Function<NDList,NDList> lambda) | Creates a LambdaBlock that can apply the specified function.
LambdaBlock(java.util.function.Function<NDList,NDList> lambda, java.lang.String name) | Creates a LambdaBlock with the given name that can apply the specified function.
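The two constructors differ only in whether a name is supplied. A common Java pattern, and a plausible reading given that the class also declares a DEFAULT_NAME constant, is for the one-argument form to delegate to the two-argument form with that default; the Javadoc does not state this, so treat it as an assumption. The sketch below is hypothetical, stdlib-only code (NamedToyBlock is not a DJL class), and its DEFAULT_NAME value is a placeholder, not DJL's actual constant.

```java
import java.util.List;
import java.util.function.Function;

/** Hypothetical sketch of the two-constructor naming pattern; NOT DJL code. */
final class NamedToyBlock {
    // Placeholder chosen for this sketch; DJL's real DEFAULT_NAME value is not assumed.
    static final String DEFAULT_NAME = "toy-default";

    private final Function<List<float[]>, List<float[]>> lambda;
    private final String name;

    /** Assumed behavior: fall back to the default name. */
    NamedToyBlock(Function<List<float[]>, List<float[]>> lambda) {
        this(lambda, DEFAULT_NAME);
    }

    NamedToyBlock(Function<List<float[]>, List<float[]>> lambda, String name) {
        this.lambda = lambda;
        this.name = name;
    }

    /** Analogue of getName(): returns the name given at construction. */
    String getName() {
        return name;
    }

    List<float[]> forward(List<float[]> inputs) {
        return lambda.apply(inputs);
    }

    public static void main(String[] args) {
        System.out.println(new NamedToyBlock(x -> x).getName());             // toy-default
        System.out.println(new NamedToyBlock(x -> x, "identity").getName()); // identity
    }
}
```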
Method Summary
Modifier and Type | Method | Description
protected NDList | forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params) | A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
java.lang.String | getName() | Returns the lambda function name.
Shape[] | getOutputShapes(Shape[] inputShapes) | Returns the expected output shapes of the block for the specified input shapes.
void | loadParameters(NDManager manager, java.io.DataInputStream is) | Loads the parameters from the given input stream.
static LambdaBlock | singleton(java.util.function.Function<NDArray,NDArray> lambda) | Creates a LambdaBlock for a singleton function.
static LambdaBlock | singleton(java.util.function.Function<NDArray,NDArray> lambda, java.lang.String name) | Creates a LambdaBlock for a singleton function.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, getOutputShapes
Field Detail
DEFAULT_NAME
public static final java.lang.String DEFAULT_NAME
See Also:
- Constant Field Values
Method Detail
getName
public java.lang.String getName()
Returns the lambda function name.
Returns:
- the lambda function name
singleton
public static LambdaBlock singleton(java.util.function.Function<NDArray,NDArray> lambda)
Creates a LambdaBlock for a singleton function.
Parameters:
- lambda - a function that maps a single NDArray to another NDArray; the resulting block accepts a singleton NDList and returns a singleton NDList
Returns:
- a new LambdaBlock for the function
singleton
public static LambdaBlock singleton(java.util.function.Function<NDArray,NDArray> lambda, java.lang.String name)
Creates a LambdaBlock for a singleton function.
Parameters:
- lambda - a function that maps a single NDArray to another NDArray; the resulting block accepts a singleton NDList and returns a singleton NDList
- name - the function name
Returns:
- a new LambdaBlock for the function
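Both singleton overloads adapt an array-to-array function into the list-to-list form a block needs. The following hypothetical, stdlib-only sketch models that adaptation (float[] for NDArray, List<float[]> for NDList); SingletonAdapter and doubleAll are illustrative names, not DJL API. Rejecting any input list whose size is not one is this sketch's own choice; how DJL reacts to a non-singleton input is not stated on this page.

```java
import java.util.List;
import java.util.function.Function;

/** Hypothetical model of the singleton adapter; NOT DJL code. */
final class SingletonAdapter {
    /** Lifts an array-to-array function to a list-to-list function that requires exactly one input. */
    static Function<List<float[]>, List<float[]>> singleton(Function<float[], float[]> f) {
        return inputs -> {
            // Enforce the "singleton" contract: exactly one array in, exactly one array out.
            if (inputs.size() != 1) {
                throw new IllegalArgumentException("Expected a singleton list, got size " + inputs.size());
            }
            return List.of(f.apply(inputs.get(0)));
        };
    }

    /** Example array function: doubles every element. */
    static float[] doubleAll(float[] x) {
        float[] y = new float[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = 2f * x[i];
        }
        return y;
    }

    public static void main(String[] args) {
        Function<List<float[]>, List<float[]>> block = singleton(SingletonAdapter::doubleAll);
        float[] out = block.apply(List.of(new float[] {1f, 3f})).get(0);
        System.out.println(out[0] + " " + out[1]); // 2.0 6.0
    }
}
```

The design benefit is that the caller writes only the single-array logic, while the adapter handles wrapping and unwrapping the list.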
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
- forwardInternal in class AbstractBaseBlock
Parameters:
- parameterStore - the parameter store
- inputs - the input NDList
- training - true for a training forward pass
- params - optional parameters
Returns:
- the output of the forward pass
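Because the wrapped Function<NDList,NDList> receives only the inputs, it has no access to the parameter store, the training flag, or the extra params. The hypothetical sketch below models that consequence: the extra arguments are accepted but unused, so a pure function gives the same result in training and inference passes. ToyForward is an illustrative class, not DJL code; Object and Map stand in for ParameterStore and PairList, and the claim that DJL's forwardInternal simply applies the function is an assumption consistent with, but not stated by, this page.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical model of a parameter-less forward pass; NOT DJL code. */
final class ToyForward {
    private final Function<List<float[]>, List<float[]>> lambda;

    ToyForward(Function<List<float[]>, List<float[]>> lambda) {
        this.lambda = lambda;
    }

    /** parameterStore, training and params are accepted but unused: the block owns no parameters. */
    List<float[]> forwardInternal(Object parameterStore, List<float[]> inputs,
                                  boolean training, Map<String, Object> params) {
        return lambda.apply(inputs);
    }

    public static void main(String[] args) {
        ToyForward negate = new ToyForward(xs -> {
            float[] x = xs.get(0);
            float[] y = new float[x.length];
            for (int i = 0; i < x.length; i++) {
                y[i] = -x[i];
            }
            return List.of(y);
        });
        List<float[]> in = List.of(new float[] {1f, -2f});
        float[] train = negate.forwardInternal(null, in, true, Map.of()).get(0);
        float[] infer = negate.forwardInternal(null, in, false, Map.of()).get(0);
        System.out.println(Arrays.equals(train, infer)); // true
    }
}
```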
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Parameters:
- inputShapes - the shapes of the inputs
Returns:
- the expected output shapes of the block
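Since the wrapped function is opaque, there is no layer definition from which to read output shapes. One plausible strategy, which this sketch assumes and which this page does not confirm for LambdaBlock, is to run the function once on placeholder inputs of the requested shapes and report the shapes of the results. The hypothetical, stdlib-only ShapeProbe below models shapes as lengths of 1-D float[] arrays; it is not DJL code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

/** Hypothetical sketch: infer output lengths of an opaque function by probing it with zero-filled inputs. */
final class ShapeProbe {
    static int[] outputLengths(Function<List<float[]>, List<float[]>> lambda, int[] inputLengths) {
        List<float[]> dummies = new ArrayList<>(inputLengths.length);
        for (int len : inputLengths) {
            dummies.add(new float[len]); // Java zero-initializes new arrays
        }
        List<float[]> outputs = lambda.apply(dummies);
        int[] result = new int[outputs.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = outputs.get(i).length;
        }
        return result;
    }

    /** Example function whose output shape depends on its inputs: concatenate all arrays. */
    static List<float[]> concat(List<float[]> inputs) {
        int total = 0;
        for (float[] x : inputs) {
            total += x.length;
        }
        float[] out = new float[total];
        int pos = 0;
        for (float[] x : inputs) {
            System.arraycopy(x, 0, out, pos, x.length);
            pos += x.length;
        }
        return List.of(out);
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(outputLengths(ShapeProbe::concat, new int[] {3, 4}))); // [7]
    }
}
```

The trade-off of probing is one extra evaluation of the function, in exchange for not needing any shape rule from the caller.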
loadParameters
public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
Specified by:
- loadParameters in interface Block
Overrides:
- loadParameters in class AbstractBaseBlock
Parameters:
- manager - an NDManager to create the parameter arrays
- is - the input stream that streams the parameter values
Throws:
- java.io.IOException - if an I/O error occurs
- MalformedModelException - if the model file is corrupted or unsupported