Package ai.djl.nn.core
Class ConstantEmbedding
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.core.ConstantEmbedding
All Implemented Interfaces:
Block, AbstractEmbedding, AbstractIndexedEmbedding
An AbstractIndexedEmbedding that always returns a constant value.
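The class description can be illustrated with a minimal plain-Java model that does not need a deep learning engine. This is a sketch, not DJL code. The class ConstantLookupModel and its float[]-based lookup method are hypothetical stand-ins for ConstantEmbedding and NDArray. It shows only the documented idea that every item, whatever its value or type, receives the same constant vector.

```java
import java.util.Arrays;

// Hypothetical plain-Java model of a constant embedding; this is NOT the DJL API.
final class ConstantLookupModel {
    private final float[] constant;

    ConstantLookupModel(float[] constant) {
        // Defensive copy so later changes to the caller's array do not leak in.
        this.constant = constant.clone();
    }

    // Every item, regardless of its value or type, maps to the same vector.
    float[] lookup(Object item) {
        return constant.clone();
    }

    public static void main(String[] args) {
        ConstantLookupModel model = new ConstantLookupModel(new float[] {0.5f, -1.0f});
        System.out.println(Arrays.toString(model.lookup("cat"))); // [0.5, -1.0]
        System.out.println(Arrays.toString(model.lookup(42)));    // [0.5, -1.0]
    }
}
```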
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
ConstantEmbedding(NDArray embedding)
    Constructs a constant embedding with the given constant.
Method Summary
Object decode(byte[] byteArray)
    Decodes the given byte array into an object of the input parameter type.
NDArray embed(NDManager manager, Object[] items)
    Embeds an array of items.
long embed(Object item)
    Embeds an item.
byte[] encode(Object input)
    Encodes an object of the input type into a byte array.
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
    A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] getOutputShapes(Shape[] inputShapes)
    Returns the expected output shapes of the block for the specified input shapes.
boolean hasItem(Object item)
    Returns whether an item is in the embedding.
void loadParameters(NDManager manager, DataInputStream is)
    Loads the parameters from the given input stream.
void saveParameters(DataOutputStream os)
    Writes the parameters of the block to the given output stream.
Optional<?> unembed(long index)
    Returns the item corresponding to the given index.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
Field Details
embedding
Constructor Details
ConstantEmbedding
ConstantEmbedding(NDArray embedding)
Constructs a constant embedding with the given constant. The constant is assumed to be a fixed value and starts out frozen. To unfreeze it, use Block.freezeParameters(boolean).
Parameters:
embedding - the value to return for all embeddings
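The sketch below models what "starts out frozen" means during training, under the assumption that a frozen parameter is simply skipped by the optimizer. FrozenConstantModel, applyGradient, and current are hypothetical names. Only freezeParameters(boolean) mirrors a real method name (Block.freezeParameters(boolean)), and the plain SGD step stands in for whatever optimizer a DJL trainer would actually use.

```java
// Hypothetical model of a parameter that starts out frozen; this is NOT the DJL API.
final class FrozenConstantModel {
    private final float[] value;
    private boolean frozen = true; // frozen by default, as the constructor docs describe

    FrozenConstantModel(float[] value) {
        this.value = value.clone();
    }

    // Mirrors the idea of Block.freezeParameters(boolean): true freezes, false unfreezes.
    void freezeParameters(boolean freeze) {
        this.frozen = freeze;
    }

    // Plain SGD step (value -= learningRate * gradient); skipped entirely while frozen.
    void applyGradient(float[] gradient, float learningRate) {
        if (frozen) {
            return;
        }
        for (int i = 0; i < value.length; i++) {
            value[i] -= learningRate * gradient[i];
        }
    }

    float[] current() {
        return value.clone();
    }

    public static void main(String[] args) {
        FrozenConstantModel p = new FrozenConstantModel(new float[] {1.0f});
        p.applyGradient(new float[] {1.0f}, 0.5f); // ignored: still frozen
        System.out.println(p.current()[0]);        // 1.0
        p.freezeParameters(false);                 // unfreeze
        p.applyGradient(new float[] {1.0f}, 0.5f);
        System.out.println(p.current()[0]);        // 0.5
    }
}
```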
Method Details
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass
getOutputShapes
Returns the expected output shapes of the block for the specified input shapes.
Specified by:
getOutputShapes in interface Block
Parameters:
inputShapes - the shapes of the inputs
Returns:
the expected output shapes of the block
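This page does not spell out the shape rule. A plausible reading, consistent with the batch embed method returning Shape(items.length, embeddingSize), is that the output shape is the input shape followed by the shape of the constant. The sketch below encodes that assumption with plain long[] arrays in place of DJL's Shape class. OutputShapeSketch is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical shape rule (an ASSUMPTION, not taken from the DJL source):
// output dims = input dims followed by the constant embedding's dims.
final class OutputShapeSketch {
    static long[] outputShape(long[] inputShape, long[] embeddingShape) {
        // Copy the input dims, leaving room for the embedding dims at the end.
        long[] out = Arrays.copyOf(inputShape, inputShape.length + embeddingShape.length);
        System.arraycopy(embeddingShape, 0, out, inputShape.length, embeddingShape.length);
        return out;
    }

    public static void main(String[] args) {
        // A batch of 4 items and an embedding of size 8.
        System.out.println(Arrays.toString(outputShape(new long[] {4}, new long[] {8})));    // [4, 8]
        // A 2 x 3 grid of items gives a 2 x 3 x 8 result.
        System.out.println(Arrays.toString(outputShape(new long[] {2, 3}, new long[] {8}))); // [2, 3, 8]
    }
}
```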
saveParameters
Writes the parameters of the block to the given output stream.
Specified by:
saveParameters in interface Block
Overrides:
saveParameters in class AbstractBaseBlock
Parameters:
os - the output stream to save the parameters to
loadParameters
Loads the parameters from the given input stream.
Specified by:
loadParameters in interface Block
Overrides:
loadParameters in class AbstractBaseBlock
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
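saveParameters and loadParameters are a matched pair over DataOutputStream and DataInputStream. The round trip below shows that pairing with the standard java.io streams. The byte layout (an int length followed by the float values) is a hypothetical format chosen for illustration. This page does not say what, if anything, ConstantEmbedding writes, and ParameterStreamSketch is not a DJL class.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

// Hypothetical serialization format, for illustration only:
// [int length][float 0]...[float length-1]
final class ParameterStreamSketch {
    static void save(float[] values, DataOutputStream os) throws IOException {
        os.writeInt(values.length);
        for (float v : values) {
            os.writeFloat(v);
        }
    }

    static float[] load(DataInputStream is) throws IOException {
        int length = is.readInt();
        float[] values = new float[length];
        for (int i = 0; i < length; i++) {
            values[i] = is.readFloat();
        }
        return values;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            save(new float[] {0.25f, 1.5f}, os);
        }
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
            System.out.println(Arrays.toString(load(is))); // [0.25, 1.5]
        }
    }
}
```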
unembed
Returns the item corresponding to the given index.
Specified by:
unembed in interface AbstractIndexedEmbedding
Parameters:
index - the index
Returns:
the item corresponding to the given index
encode
Encodes an object of the input type into a byte array. This is used in saving and loading Embedding objects.
Specified by:
encode in interface AbstractIndexedEmbedding
Parameters:
input - the input object to be encoded
Returns:
the encoded byte array
decode
Decodes the given byte array into an object of the input parameter type.
Specified by:
decode in interface AbstractIndexedEmbedding
Parameters:
byteArray - the byte array to be decoded
Returns:
the decoded object of the input parameter type
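Because encode and decode are used to save and load embeddings, they should invert each other. The sketch below assumes String items and UTF-8 bytes. ItemCodecSketch is a hypothetical class, not the codec ConstantEmbedding uses.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical encoder/decoder pair for String items; this is NOT DJL code.
// The contract illustrated: decode(encode(x)) must give back an equal object.
final class ItemCodecSketch {
    static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String item = "na\u00efve";                     // "naive" with a diaeresis on the i
        byte[] bytes = encode(item);
        System.out.println(bytes.length);               // 6: U+00EF takes two bytes in UTF-8
        System.out.println(decode(bytes).equals(item)); // true
    }
}
```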
embed
Embeds an item.
Specified by:
embed in interface AbstractIndexedEmbedding
Parameters:
item - the item to embed
Returns:
the index of the item in the embedding
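embed(item) and unembed(index) form the index contract of AbstractIndexedEmbedding: an item maps to a long index, and an index maps back to an Optional item. The sketch below shows that contract for a generic, growing vocabulary rather than for ConstantEmbedding itself, since this page does not say which index ConstantEmbedding returns or what its unembed yields. IndexedVocabularySketch is a hypothetical class.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

// Hypothetical indexed vocabulary illustrating the embed/unembed/hasItem contract.
final class IndexedVocabularySketch<T> {
    private final Map<T, Long> indexOf = new HashMap<>();
    private final List<T> items = new ArrayList<>();

    // Assigns the next free index to an unseen item; returns the existing index otherwise.
    long embed(T item) {
        return indexOf.computeIfAbsent(item, k -> {
            items.add(k);
            return (long) (items.size() - 1);
        });
    }

    // Maps an index back to its item, or Optional.empty() for an unknown index.
    Optional<T> unembed(long index) {
        if (index < 0 || index >= items.size()) {
            return Optional.empty();
        }
        return Optional.of(items.get((int) index));
    }

    boolean hasItem(T item) {
        return indexOf.containsKey(item);
    }

    public static void main(String[] args) {
        IndexedVocabularySketch<String> vocab = new IndexedVocabularySketch<>();
        System.out.println(vocab.embed("cat"));    // 0
        System.out.println(vocab.embed("dog"));    // 1
        System.out.println(vocab.embed("cat"));    // 0 (already present)
        System.out.println(vocab.unembed(1));      // Optional[dog]
        System.out.println(vocab.unembed(7));      // Optional.empty
        System.out.println(vocab.hasItem("bird")); // false
    }
}
```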
embed
Embeds an array of items.
Specified by:
embed in interface AbstractEmbedding
Parameters:
manager - the manager for the new embeddings
items - the items to embed
Returns:
the embedding NDArray of Shape(items.length, embeddingSize)
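The batch embed method returns an NDArray of Shape(items.length, embeddingSize). The model below reproduces that shape with a float[][] standing in for NDArray: one row per item, each row equal to the constant. BatchConstantEmbedSketch is hypothetical, and copying each row independently is a design choice of the sketch, not documented DJL behavior.

```java
import java.util.Arrays;

// Hypothetical plain-Java model of batch embedding with a constant vector:
// the result has one row per item, i.e. shape (items.length, embeddingSize).
final class BatchConstantEmbedSketch {
    static float[][] embed(float[] constant, Object[] items) {
        float[][] out = new float[items.length][];
        for (int i = 0; i < items.length; i++) {
            out[i] = constant.clone(); // each row is an independent copy of the constant
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] batch = embed(new float[] {1f, 2f, 3f}, new Object[] {"a", "b"});
        System.out.println(batch.length + " x " + batch[0].length); // 2 x 3
        System.out.println(Arrays.deepToString(batch));             // [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
    }
}
```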
hasItem
Returns whether an item is in the embedding.
Specified by:
hasItem in interface AbstractEmbedding
Parameters:
item - the item to test
Returns:
true if the item is in the embedding