Package ai.djl.nn.core
Class Embedding<T>
java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.core.Embedding<T>
- Type Parameters:
T - the type of item that should be embedded and mapped to the array
- All Implemented Interfaces:
Block, AbstractEmbedding<T>, AbstractIndexedEmbedding<T>
- Direct Known Subclasses:
TrainableWordEmbedding
An Embedding block maps a collection of items to 1-dimensional representative NDArrays.
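The mapping described above can be illustrated without DJL. The following is a minimal plain-Java sketch, not DJL's implementation: the class `ItemEmbeddingSketch` and its methods `indexOf`, `itemAt`, and `vectorOf` are hypothetical names. It assumes each item is first assigned an integer index and that the index selects one row of a table. Reserving row 0 for unknown items is also an assumption of the sketch.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: item -> integer index -> one row (a 1-D vector) of a table. */
final class ItemEmbeddingSketch<T> {
    // Assumption of this sketch: row 0 is reserved for items outside the vocabulary.
    static final int UNKNOWN_INDEX = 0;

    private final Map<T, Integer> itemToIndex = new HashMap<>();
    private final List<T> indexToItem = new ArrayList<>();
    private final float[][] table;

    ItemEmbeddingSketch(List<T> vocabulary, float[][] table) {
        indexToItem.add(null); // placeholder for the reserved unknown row
        for (T item : vocabulary) {
            if (!itemToIndex.containsKey(item)) { // duplicates keep their first index
                itemToIndex.put(item, indexToItem.size());
                indexToItem.add(item);
            }
        }
        if (table.length != indexToItem.size()) {
            throw new IllegalArgumentException(
                    "expected " + indexToItem.size() + " rows, got " + table.length);
        }
        this.table = table;
    }

    boolean hasItem(T item) {
        return itemToIndex.containsKey(item);
    }

    /** Item -> index, falling back to the reserved unknown row. */
    int indexOf(T item) {
        return itemToIndex.getOrDefault(item, UNKNOWN_INDEX);
    }

    /** Index -> item; returns null for the reserved unknown row. */
    T itemAt(int index) {
        return indexToItem.get(index);
    }

    /** Item -> a copy of its 1-D representative vector. */
    float[] vectorOf(T item) {
        return table[indexOf(item)].clone();
    }
}
```

With the vocabulary ["cat", "dog"], the table needs three rows: the reserved row, then one row per item.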
Nested Class Summary
- static class Embedding.BaseBuilder<T, B extends Embedding.BaseBuilder<T, B>>
- protected class
- protected class
Field Summary
- protected Parameter embedding
- protected int embeddingSize
- protected AbstractIndexedEmbedding<T> fallthroughEmbedding
- protected int numEmbeddings
- protected SparseFormat sparseFormat
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
Constructor Summary
- protected Embedding(NDArray embedding)
Constructs a pretrained embedding.
- protected Embedding(NDArray embedding, SparseFormat format)
Constructs a pretrained embedding.
- protected Embedding(Embedding.BaseBuilder<T, ?> baseBuilder)
Method Summary
- NDArray embed(NDManager manager, T[] items)
Embeds an array of items.
- static NDList embedding(NDArray input, NDArray weight, SparseFormat sparse)
A simple lookup table that looks up embeddings in a fixed dictionary and size.
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
- void loadParameters(NDManager manager, DataInputStream is)
Loads the parameters from the given input stream.
- void prepare(Shape[] inputShapes)
Sets the shape of Parameters.
- void saveParameters(DataOutputStream os)
Writes the parameters of the block to the given outputStream.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.core.AbstractEmbedding
hasItem
Methods inherited from interface ai.djl.nn.core.AbstractIndexedEmbedding
decode, embed, encode, unembed
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getCustomMetadata, getOutputShapes
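Field Details
- numEmbeddings
protected int numEmbeddings
- embeddingSize
protected int embeddingSize
- sparseFormat
protected SparseFormat sparseFormat
- fallthroughEmbedding
protected AbstractIndexedEmbedding<T> fallthroughEmbedding
- embedding
protected Parameter embedding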
Constructor Details
- Embedding
protected Embedding(Embedding.BaseBuilder<T, ?> baseBuilder)
- Embedding
protected Embedding(NDArray embedding)
Constructs a pretrained embedding.
- Parameters:
embedding - the embedding array
- Embedding
protected Embedding(NDArray embedding, SparseFormat format)
Constructs a pretrained embedding. Because it is created with pretrained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean) with false.
- Parameters:
embedding - the embedding array
format - the SparseFormat of the gradient, i.e. whether to compute a row-sparse gradient in the backward calculation
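A pretrained embedding starts frozen and must be unfrozen before training can change it. The plain-Java sketch below illustrates that life cycle under stated assumptions: `PretrainedTableSketch`, `applyGradient`, and `get` are hypothetical, `freezeParameters` only mirrors the name of the DJL method, and deriving `numEmbeddings` and `embeddingSize` from the rows and columns of the supplied matrix is assumed rather than taken from this page.

```java
/** Hypothetical sketch of a pretrained table that starts frozen. */
final class PretrainedTableSketch {
    final int numEmbeddings; // assumed: number of rows in the pretrained matrix
    final int embeddingSize; // assumed: number of columns in the pretrained matrix
    private final float[][] weight;
    private boolean frozen = true; // pretrained data starts frozen

    PretrainedTableSketch(float[][] pretrained) {
        if (pretrained.length == 0) {
            throw new IllegalArgumentException("empty matrix");
        }
        numEmbeddings = pretrained.length;
        embeddingSize = pretrained[0].length;
        weight = new float[numEmbeddings][];
        for (int i = 0; i < numEmbeddings; i++) {
            if (pretrained[i].length != embeddingSize) {
                throw new IllegalArgumentException("ragged matrix at row " + i);
            }
            weight[i] = pretrained[i].clone(); // defensive copy of the caller's data
        }
    }

    /** true blocks updates, false allows them. */
    void freezeParameters(boolean freeze) {
        frozen = freeze;
    }

    boolean isFrozen() {
        return frozen;
    }

    /** Plain gradient-descent step on one row; rejected while frozen. */
    void applyGradient(int row, float[] grad, float learningRate) {
        if (frozen) {
            throw new IllegalStateException("table is frozen; call freezeParameters(false) first");
        }
        for (int j = 0; j < embeddingSize; j++) {
            weight[row][j] -= learningRate * grad[j];
        }
    }

    float get(int row, int col) {
        return weight[row][col];
    }
}
```

Passing `false` unfreezes the table; passing `true` freezes it again.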
Method Details
- prepare
Sets the shape of Parameters.
- Overrides:
prepare in class AbstractBaseBlock
- Parameters:
inputShapes - the shapes of inputs
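As a rough illustration of what setting the parameter shape involves, the sketch below assumes the embedding parameter is a (numEmbeddings, embeddingSize) matrix, matching the row and column description of the weight in the embedding method, and that this shape does not depend on the input shapes, since inputs only carry indices. `PrepareSketch`, `weightShape`, and `parameterCount` are hypothetical helpers.

```java
/** Hypothetical helpers describing the embedding parameter's shape. */
final class PrepareSketch {
    private PrepareSketch() {}

    /** The weight is (numEmbeddings, embeddingSize); inputShapes is ignored on purpose. */
    static long[] weightShape(long[][] inputShapes, int numEmbeddings, int embeddingSize) {
        // Inputs only hold indices, so every input shape shares the same table.
        return new long[] {numEmbeddings, embeddingSize};
    }

    /** Number of trainable values in the table. */
    static long parameterCount(int numEmbeddings, int embeddingSize) {
        return (long) numEmbeddings * embeddingSize; // widen first to avoid int overflow
    }
}
```

For example, a 10,000-item vocabulary with 300-dimensional vectors gives a 10000 × 300 weight, or 3,000,000 values.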
- getOutputShapes
Returns the expected output shapes of the block for the specified input shapes.
- Specified by:
getOutputShapes in interface Block
- Parameters:
inputShapes - the shapes of the inputs
- Returns:
the expected output shapes of the block
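For an embedding, a natural reading of this method is that every index in the input becomes a vector of length embeddingSize, so the output shape is the input shape with embeddingSize appended. The sketch below encodes that rule as an assumption: `OutputShapeSketch` is hypothetical, shapes are plain `long[]` arrays instead of DJL `Shape` objects, and only the first input is considered.

```java
import java.util.Arrays;

/** Hypothetical shape rule: (d0, ..., dk) -> (d0, ..., dk, embeddingSize). */
final class OutputShapeSketch {
    private OutputShapeSketch() {}

    static long[][] outputShapes(long[][] inputShapes, int embeddingSize) {
        long[] indices = inputShapes[0]; // assumption: the first input holds the indices
        long[] out = Arrays.copyOf(indices, indices.length + 1);
        out[indices.length] = embeddingSize; // append the vector length as the last axis
        return new long[][] {out};
    }
}
```

Under this rule, an index array of shape (2, 3) with embeddingSize 4 yields an output of shape (2, 3, 4).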
- forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String, Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Specified by:
forwardInternal in class AbstractBaseBlock
- Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
- Returns:
the output of the forward pass
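The forward pass amounts to fetching the current weight and performing the lookup. The sketch below shows that flow in plain Java under stated assumptions: a `Map` stands in for the ParameterStore, the parameter name `"embedding"` and the class `ForwardSketch` are hypothetical, and the input is a (batch, seqLen) array of indices. The `training` flag is accepted only to mirror the signature; it does not change the result of a lookup here.

```java
import java.util.Map;

/** Hypothetical forward pass: fetch the weight, then look up every index. */
final class ForwardSketch {
    private final Map<String, float[][]> store; // stand-in for a parameter store

    ForwardSketch(Map<String, float[][]> store) {
        this.store = store;
    }

    /** (batch, seqLen) indices -> (batch, seqLen, embeddingSize) vectors. */
    float[][][] forward(int[][] indices, boolean training) {
        float[][] weight = store.get("embedding"); // hypothetical parameter name
        if (weight == null) {
            throw new IllegalStateException("block not initialized");
        }
        float[][][] out = new float[indices.length][][];
        for (int b = 0; b < indices.length; b++) {
            out[b] = new float[indices[b].length][];
            for (int s = 0; s < indices[b].length; s++) {
                out[b][s] = weight[indices[b][s]].clone(); // copy the selected row
            }
        }
        return out;
    }
}
```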
- saveParameters
Writes the parameters of the block to the given outputStream.
- Specified by:
saveParameters in interface Block
- Overrides:
saveParameters in class AbstractBaseBlock
- Parameters:
os - the output stream to save the parameters to
- Throws:
IOException - if an I/O error occurs
- loadParameters
public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException
Loads the parameters from the given input stream.
- Specified by:
loadParameters in interface Block
- Overrides:
loadParameters in class AbstractBaseBlock
- Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
- Throws:
IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
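Saving and loading are inverse operations over a byte stream. The round trip below is a plain-Java sketch using `java.io.DataOutputStream` and `java.io.DataInputStream`; the class `TableIoSketch` is hypothetical, and its layout (a version byte, the row and column counts, then the values row by row) is a format chosen for illustration, not the format DJL actually writes. A version mismatch is reported as an `IOException` here, whereas the method above reports corrupted or unsupported data as `MalformedModelException`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical layout: version byte, rows, cols, then floats row by row. */
final class TableIoSketch {
    static final byte VERSION = 1;

    private TableIoSketch() {}

    static void save(float[][] weight, DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(weight.length);
        os.writeInt(weight.length == 0 ? 0 : weight[0].length);
        for (float[] row : weight) {
            for (float v : row) {
                os.writeFloat(v);
            }
        }
    }

    static float[][] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) { // reject data written in an unknown layout
            throw new IOException("unsupported version " + version);
        }
        int rows = is.readInt();
        int cols = is.readInt();
        float[][] weight = new float[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                weight[i][j] = is.readFloat();
            }
        }
        return weight;
    }

    static byte[] toBytes(float[][] weight) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(bytes)) {
            save(weight, os);
        }
        return bytes.toByteArray();
    }

    static float[][] fromBytes(byte[] data) throws IOException {
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(data))) {
            return load(is);
        }
    }
}
```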
- embed
Embeds an array of items.
- Specified by:
embed in interface AbstractEmbedding<T>
- Parameters:
manager - the manager for the new embeddings
items - the items to embed
- Returns:
the embedding NDArray of Shape(items.length, embeddingSize)
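Embedding a batch of items stacks one row per item, which gives the documented shape (items.length, embeddingSize). The plain-Java sketch below assumes a `Map` from items to row indices and a `float[][]` table; `BatchEmbedSketch` is a hypothetical name, and rejecting unknown items is a choice of the sketch rather than documented behavior.

```java
import java.util.Map;

/** Hypothetical batch embedding: one table row per item. */
final class BatchEmbedSketch {
    private BatchEmbedSketch() {}

    /** Returns a matrix of shape (items.length, embeddingSize). */
    static <T> float[][] embed(T[] items, Map<T, Integer> indexOf, float[][] table) {
        float[][] out = new float[items.length][];
        for (int i = 0; i < items.length; i++) {
            Integer index = indexOf.get(items[i]);
            if (index == null) {
                throw new IllegalArgumentException("unknown item: " + items[i]);
            }
            out[i] = table[index].clone(); // row i is the vector of items[i]
        }
        return out;
    }
}
```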
- embedding
A simple lookup table that looks up embeddings in a fixed dictionary and size.
- Parameters:
input - NDArray containing indices into the embedding matrix
weight - the embedding matrix, with the number of rows equal to the maximum possible index + 1 and the number of columns equal to the embedding size
sparse - SparseFormat of the gradient
- Returns:
output NDArray
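The lookup can be written as out[i] = weight[input[i]], which is why the weight needs (maximum index + 1) rows. The sketch below, with the hypothetical class `LookupSketch`, implements that rule on plain arrays and adds a helper that lists the distinct rows an input touches: only those rows can receive a nonzero gradient, which is what makes a row-sparse gradient format worthwhile. It does not model DJL's SparseFormat itself.

```java
import java.util.Arrays;

/** Hypothetical plain-array version of an embedding lookup. */
final class LookupSketch {
    private LookupSketch() {}

    /** out[i] = weight[input[i]]; each row of weight has embeddingSize columns. */
    static float[][] embedding(int[] input, float[][] weight) {
        float[][] out = new float[input.length][];
        for (int i = 0; i < input.length; i++) {
            int idx = input[i];
            if (idx < 0 || idx >= weight.length) {
                throw new IndexOutOfBoundsException(
                        "index " + idx + " is outside a table with " + weight.length + " rows");
            }
            out[i] = weight[idx].clone();
        }
        return out;
    }

    /** Minimum row count for this input: the maximum index + 1. */
    static int requiredRows(int[] input) {
        return Arrays.stream(input).max().orElse(-1) + 1;
    }

    /** Distinct rows read by this input; only they can receive a nonzero gradient. */
    static int[] touchedRows(int[] input) {
        return Arrays.stream(input).distinct().sorted().toArray();
    }
}
```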