Package ai.djl.nn.core
Class ConstantEmbedding
java.lang.Object
    ai.djl.nn.AbstractBaseBlock
        ai.djl.nn.AbstractBlock
            ai.djl.nn.core.ConstantEmbedding
All Implemented Interfaces:
    Block, AbstractEmbedding, AbstractIndexedEmbedding

public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedEmbedding

An AbstractIndexedEmbedding that always returns a constant value, whatever item it is given.
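To make the idea concrete, here is a minimal plain-Java sketch, not DJL code. It assumes a `float[]` can stand in for the `NDArray` constant, and the class `ConstantLookup` and its `vectorFor` method are hypothetical names. Whatever item is passed in, the same vector comes back.

```java
import java.util.Arrays;

/** Hypothetical plain-Java model of a constant embedding; not the DJL class. */
final class ConstantLookup {
    private final float[] constant;

    ConstantLookup(float[] constant) {
        // Defensive copy so later edits to the caller's array do not change the constant.
        this.constant = Arrays.copyOf(constant, constant.length);
    }

    /** Returns a fresh copy of the constant vector; the item itself is ignored. */
    float[] vectorFor(Object item) {
        return Arrays.copyOf(constant, constant.length);
    }
}
```

Returning a copy rather than the stored array keeps callers from mutating the shared constant; whether DJL copies its `NDArray` is not stated on this page.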

Field Summary
protected NDArray embedding

Fields inherited from class ai.djl.nn.AbstractBlock:
    children, parameters

Fields inherited from class ai.djl.nn.AbstractBaseBlock:
    inputNames, inputShapes, outputDataTypes, version

Constructor Summary
ConstantEmbedding(NDArray embedding)
    Constructs a constant embedding with the given constant.

Method Summary
java.lang.Object decode(byte[] byteArray)
    Decodes the given byte array into an object of input parameter type.
NDArray embed(NDManager manager, java.lang.Object[] items)
    Embeds an array of items.
long embed(java.lang.Object item)
    Embeds an item.
byte[] encode(java.lang.Object input)
    Encodes an object of input type into a byte array.
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
    A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Shape[] getOutputShapes(Shape[] inputShapes)
    Returns the expected output shapes of the block for the specified input shapes.
boolean hasItem(java.lang.Object item)
    Returns whether an item is in the embedding.
void loadParameters(NDManager manager, java.io.DataInputStream is)
    Loads the parameters from the given input stream.
void saveParameters(java.io.DataOutputStream os)
    Writes the parameters of the block to the given output stream.
java.util.Optional<?> unembed(long index)
    Returns the item corresponding to the given index.

Methods inherited from class ai.djl.nn.AbstractBlock:
    addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters

Methods inherited from class ai.djl.nn.AbstractBaseBlock:
    beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString

Methods inherited from class java.lang.Object:
    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface ai.djl.nn.Block:
    forward, freezeParameters, getOutputShapes

Field Detail

embedding
protected NDArray embedding

Constructor Detail

ConstantEmbedding
public ConstantEmbedding(NDArray embedding)
Constructs a constant embedding with the given constant. The constant is assumed to be a fixed value and starts out frozen. To unfreeze it, use Block.freezeParameters(boolean).
Parameters:
    embedding - the value to return for all embeddings
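The frozen-by-default behavior described above can be modeled in a few lines. This is a hypothetical sketch: `FrozenConstant`, `isFrozen()`, and `value()` are invented for illustration, and only the idea of starting frozen and toggling via `freezeParameters(boolean)` comes from this page.

```java
/** Hypothetical model of a block whose single parameter starts out frozen. */
final class FrozenConstant {
    private final float[] value;
    private boolean frozen;

    FrozenConstant(float[] value) {
        this.value = value.clone();
        this.frozen = true; // mirrors "starts out frozen" in the constructor description
    }

    /** Plays the role of Block.freezeParameters(boolean): true freezes, false unfreezes. */
    void freezeParameters(boolean freeze) {
        this.frozen = freeze;
    }

    boolean isFrozen() {
        return frozen;
    }

    float[] value() {
        return value.clone();
    }
}
```

On a real `ConstantEmbedding`, the equivalent step would be calling `freezeParameters(false)`, which the class inherits from `Block`, before training.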

Method Detail

forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
    forwardInternal in class AbstractBaseBlock
Parameters:
    parameterStore - the parameter store
    inputs - the input NDList
    training - true for a training forward pass
    params - optional parameters
Returns:
    the output of the forward pass

getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Specified by:
    getOutputShapes in interface Block
Parameters:
    inputShapes - the shapes of the inputs
Returns:
    the expected output shapes of the block
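Embedding blocks commonly compute their output shape by appending the embedding's own shape to the input shape, which agrees with the `Shape(items.length, embeddingSize)` result documented for the batched `embed` method. The sketch below assumes that rule for `getOutputShapes` (this page does not state it) and uses `long[]` in place of DJL's `Shape`; `ShapeRule.outputShape` is a hypothetical helper.

```java
import java.util.Arrays;

/** Hypothetical shape helper; the concatenation rule is an assumption. */
final class ShapeRule {
    /** Appends the embedding shape to the input shape, e.g. (2, 3) + (4) -> (2, 3, 4). */
    static long[] outputShape(long[] inputShape, long[] embeddingShape) {
        // Start from the input dimensions, leaving room for the embedding dimensions.
        long[] out = Arrays.copyOf(inputShape, inputShape.length + embeddingShape.length);
        System.arraycopy(embeddingShape, 0, out, inputShape.length, embeddingShape.length);
        return out;
    }
}
```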

saveParameters
public void saveParameters(java.io.DataOutputStream os)
Writes the parameters of the block to the given output stream.
Specified by:
    saveParameters in interface Block
Overrides:
    saveParameters in class AbstractBaseBlock
Parameters:
    os - the output stream to save the parameters to

loadParameters
public void loadParameters(NDManager manager, java.io.DataInputStream is)
Loads the parameters from the given input stream.
Specified by:
    loadParameters in interface Block
Overrides:
    loadParameters in class AbstractBaseBlock
Parameters:
    manager - an NDManager to create the parameter arrays
    is - the input stream that supplies the parameter values
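Both methods take the `java.io` data streams shown in their signatures. As a hypothetical illustration of a save/load round trip, the sketch writes a float vector as a length prefix followed by its values and reads it back. The layout and the `VectorIO` helper are assumptions; this is not DJL's parameter format, and this page does not say what `ConstantEmbedding` writes.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical helper; the length-prefixed layout is an assumption for illustration. */
final class VectorIO {
    /** Writes the vector length, then each float value. */
    static void write(DataOutputStream os, float[] values) throws IOException {
        os.writeInt(values.length);
        for (float v : values) {
            os.writeFloat(v);
        }
    }

    /** Reads back a vector written by write(). */
    static float[] read(DataInputStream is) throws IOException {
        int n = is.readInt();
        float[] values = new float[n];
        for (int i = 0; i < n; i++) {
            values[i] = is.readFloat();
        }
        return values;
    }

    /** Saves to an in-memory buffer and loads the result again. */
    static float[] roundTrip(float[] values) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(bytes)) {
                write(os, values);
            }
            try (DataInputStream is =
                    new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return read(is);
            }
        } catch (IOException e) {
            // In-memory streams should not fail; rethrow unchecked if they somehow do.
            throw new UncheckedIOException(e);
        }
    }
}
```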

unembed
public java.util.Optional<?> unembed(long index)
Returns the item corresponding to the given index.
Specified by:
    unembed in interface AbstractIndexedEmbedding
Parameters:
    index - the index
Returns:
    the item corresponding to the given index
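Because `unembed` returns `java.util.Optional<?>`, a caller should handle the empty case rather than assume every index maps back to an item. In the sketch, a hypothetical `LongFunction` stands in for the method and `UnembedCaller` is an invented name; which indices `ConstantEmbedding` can actually map back is not documented here.

```java
import java.util.Optional;
import java.util.function.LongFunction;

/** Hypothetical caller-side helper for an unembed-style lookup. */
final class UnembedCaller {
    /** Renders the item for an index, or a placeholder when the lookup yields nothing. */
    static String describe(LongFunction<Optional<?>> unembed, long index) {
        return unembed.apply(index).map(Object::toString).orElse("<no item>");
    }
}
```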

encode
public byte[] encode(java.lang.Object input)
Encodes an object of input type into a byte array. This is used when saving and loading Embedding objects.
Specified by:
    encode in interface AbstractIndexedEmbedding
Parameters:
    input - the input object to be encoded
Returns:
    the encoded byte array

decode
public java.lang.Object decode(byte[] byteArray)
Decodes the given byte array into an object of input parameter type.
Specified by:
    decode in interface AbstractIndexedEmbedding
Parameters:
    byteArray - the byte array to be decoded
Returns:
    the decoded object of input parameter type
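`encode` and `decode` are intended as inverses so that items survive saving and loading. The sketch shows that round-trip contract with a hypothetical UTF-8 codec for `String` items (`StringItemCodec` is an invented name). It illustrates the `AbstractIndexedEmbedding` contract in general, not what `ConstantEmbedding` itself encodes, which this page does not specify.

```java
import java.nio.charset.StandardCharsets;

/** Hypothetical codec for String items; not part of DJL. */
final class StringItemCodec {
    /** Encodes a String item as UTF-8 bytes. */
    static byte[] encode(Object input) {
        return ((String) input).getBytes(StandardCharsets.UTF_8);
    }

    /** Decodes UTF-8 bytes back into a String item. */
    static Object decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }
}
```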

embed
public long embed(java.lang.Object item)
Embeds an item.
Specified by:
    embed in interface AbstractIndexedEmbedding
Parameters:
    item - the item to embed
Returns:
    the index of the item in the embedding
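Since every item receives the same constant, a natural reading is that every item also shares one index. The sketch assumes that index is `0`; this page says only that a constant value is returned, so treat the specific number as an assumption. `ConstantIndexer` is hypothetical.

```java
/** Hypothetical model of embed(Object) for a constant embedding. */
final class ConstantIndexer {
    /** Assumed shared index; the actual value is not documented on this page. */
    static final long CONSTANT_INDEX = 0L;

    /** Every item, whatever its value, maps to the same index. */
    long embed(Object item) {
        return CONSTANT_INDEX;
    }
}
```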

embed
public NDArray embed(NDManager manager, java.lang.Object[] items)
Embeds an array of items.
Specified by:
    embed in interface AbstractEmbedding
Parameters:
    manager - the manager for the new embeddings
    items - the items to embed
Returns:
    the embedding NDArray of Shape(items.length, embeddingSize)
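The batched form returns one row per item, each equal to the constant, giving the documented shape `(items.length, embeddingSize)`. A plain-Java sketch with `float[][]` standing in for `NDArray` (the class `ConstantBatch` is hypothetical, not DJL's implementation):

```java
import java.util.Arrays;

/** Hypothetical model of the batched embed for a constant embedding. */
final class ConstantBatch {
    /** Returns an items.length x constant.length matrix whose rows all equal the constant. */
    static float[][] embed(float[] constant, Object[] items) {
        float[][] out = new float[items.length][];
        for (int i = 0; i < items.length; i++) {
            // Copy per row so the caller can modify rows independently.
            out[i] = Arrays.copyOf(constant, constant.length);
        }
        return out;
    }
}
```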

hasItem
public boolean hasItem(java.lang.Object item)
Returns whether an item is in the embedding.
Specified by:
    hasItem in interface AbstractEmbedding
Parameters:
    item - the item to test
Returns:
    true if the item is in the embedding