Package ai.djl.modality.nlp.embedding
Class TrainableWordEmbedding
- All Implemented Interfaces:
WordEmbedding, Block, AbstractEmbedding<String>, AbstractIndexedEmbedding<String>
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a DefaultVocabulary. This WordEmbedding is ideal when there are no pre-trained embeddings available.
-
Nested Class Summary
Nested classes/interfaces inherited from class ai.djl.nn.core.Embedding
Embedding.BaseBuilder<T, B extends Embedding.BaseBuilder<T, B>>, Embedding.DefaultEmbedding, Embedding.DefaultItem
-
Field Summary
Fields inherited from class ai.djl.nn.core.Embedding
embedding, embeddingSize, fallthroughEmbedding, numEmbeddings, sparseFormat
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, outputDataTypes, version
-
Constructor Summary
Constructors
Constructor and Description
TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
    Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
    Constructs a new instance of TrainableWordEmbedding from a DefaultVocabulary and a given embedding size.
-
Method Summary
Modifier and Type, Method, and Description
builder()
    Creates a builder to build an Embedding.
decode(byte[] byteArray)
    Decodes the given byte array into an object of input parameter type.
long embed(String item)
    Embeds an item.
embedWord(NDArray index)
    Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
byte[] encode(String input)
    Encodes an object of input type into a byte array.
static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items)
    Constructs a pretrained embedding.
static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items, SparseFormat sparseFormat)
    Constructs a pretrained embedding.
boolean hasItem(String item)
    Returns whether an item is in the embedding.
long preprocessWordToEmbed(String word)
    Pre-processes the word to embed into an array to pass into the model.
unembed(long index)
    Returns the item corresponding to the given index.
unembedWord(NDArray word)
    Returns the closest matching word for the given index.
boolean vocabularyContains(String word)
    Returns whether an embedding exists for a word.
Methods inherited from class ai.djl.nn.core.Embedding
embed, embedding, forwardInternal, getOutputShapes, loadParameters, prepare, saveParameters
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters, freezeParameters, getOutputShapes
Methods inherited from interface ai.djl.modality.nlp.embedding.WordEmbedding
embedWord, embedWord
-
Constructor Details
-
TrainableWordEmbedding
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
- Parameters:
builder - the TrainableWordEmbedding.Builder
-
TrainableWordEmbedding
Constructs a new instance of TrainableWordEmbedding from a DefaultVocabulary and a given embedding size.
- Parameters:
vocabulary - a Vocabulary to get tokens from
embeddingSize - the required embedding size
-
-
Method Details
-
fromPretrained
Constructs a pretrained embedding.
Because it is created with pretrained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean) with false to unfreeze its parameters.
- Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
- Returns:
the created embedding
-
fromPretrained
public static TrainableWordEmbedding fromPretrained(NDArray embedding, List<String> items, SparseFormat sparseFormat)
Constructs a pretrained embedding.
Because it is created with pretrained data, it is created as a frozen block. If you wish to update it, call Block.freezeParameters(boolean) with false to unfreeze its parameters.
- Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
sparseFormat - whether to compute row sparse gradient in the backward calculation
- Returns:
the created embedding
-
vocabularyContains
Returns whether an embedding exists for a word.
- Specified by:
vocabularyContains in interface WordEmbedding
- Parameters:
word - the word to check
- Returns:
true if an embedding exists
-
preprocessWordToEmbed
Pre-processes the word to embed into an array to pass into the model.
Make sure to call WordEmbedding.embedWord(NDManager, long) after this.
- Specified by:
preprocessWordToEmbed in interface WordEmbedding
- Parameters:
word - the word to embed
- Returns:
the word that is ready to embed
-
embedWord
Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
- Specified by:
embedWord in interface WordEmbedding
- Parameters:
index - the index of the word to embed
- Returns:
the embedded word
- Throws:
EmbeddingException - if there is an error while trying to embed
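The two-step flow described for preprocessWordToEmbed and embedWord (word to index, then index to vector) can be illustrated without the library. The following is only a plain-Java sketch under stated assumptions: WordLookupSketch is a hypothetical stand-in class, it stores vectors in a float[][] rather than an NDArray, and returning -1 for an unknown word is a choice made for this sketch, not documented behavior of TrainableWordEmbedding.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical stand-in for a word embedding; not part of DJL. */
final class WordLookupSketch {
    private final List<String> vocabulary;
    private final float[][] table;

    /** Row i of the table is the vector for vocabulary.get(i). */
    WordLookupSketch(List<String> vocabulary, float[][] table) {
        if (vocabulary.size() != table.length) {
            throw new IllegalArgumentException("Expected one table row per word");
        }
        this.vocabulary = new ArrayList<>(vocabulary);
        this.table = table;
    }

    /** Returns whether the word has a row in the table. */
    boolean vocabularyContains(String word) {
        return vocabulary.contains(word);
    }

    /** Step 1: map the word to its index (-1 if unknown; an assumption of this sketch). */
    long preprocessWordToEmbed(String word) {
        return vocabulary.indexOf(word);
    }

    /** Step 2: map the index to a copy of its vector. */
    float[] embedWord(long index) {
        if (index < 0 || index >= table.length) {
            throw new IllegalArgumentException("Index out of range: " + index);
        }
        return table[(int) index].clone();
    }

    public static void main(String[] args) {
        WordLookupSketch sketch = new WordLookupSketch(
                Arrays.asList("cat", "dog"),
                new float[][] {{1f, 2f}, {3f, 4f}});
        if (sketch.vocabularyContains("dog")) {
            long index = sketch.preprocessWordToEmbed("dog");
            System.out.println(Arrays.toString(sketch.embedWord(index)));
        }
    }
}
```

Checking vocabularyContains before the two calls keeps the unknown-word case explicit instead of relying on the sketch's -1 convention.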
-
unembedWord
Returns the closest matching word for the given index.
- Specified by:
unembedWord in interface WordEmbedding
- Parameters:
word - the word embedding to find the matching string word for
- Returns:
a word similar to the passed in embedding
-
encode
Encodes an object of input type into a byte array. This is used in saving and loading the Embedding objects.
- Specified by:
encode in interface AbstractIndexedEmbedding<String>
- Parameters:
input - the input object to be encoded
- Returns:
the encoded byte array
-
decode
Decodes the given byte array into an object of input parameter type.
- Specified by:
decode in interface AbstractIndexedEmbedding<String>
- Parameters:
byteArray - the byte array to be decoded
- Returns:
the decoded object of input parameter type
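Since the item type here is String, encode and decode amount to a String to byte[] round trip. The sketch below shows one way such a pair could look, assuming UTF-8 as the charset. The charset choice and the ItemCodecSketch class name are assumptions for illustration; this page does not state which encoding the class uses.

```java
import java.nio.charset.StandardCharsets;

/** Hypothetical String/byte[] codec; the UTF-8 charset is an assumption. */
final class ItemCodecSketch {
    private ItemCodecSketch() {}

    /** Encodes the item as UTF-8 bytes. */
    static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /** Decodes UTF-8 bytes back into the item. */
    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        byte[] bytes = encode("hello");
        System.out.println(decode(bytes)); // prints "hello"
    }
}
```

Fixing the charset explicitly, rather than relying on the platform default, keeps saved items readable across machines.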
-
embed
Embeds an item.
- Specified by:
embed in interface AbstractIndexedEmbedding<String>
- Parameters:
item - the item to embed
- Returns:
the index of the item in the embedding
-
unembed
Returns the item corresponding to the given index.
- Specified by:
unembed in interface AbstractIndexedEmbedding<String>
- Parameters:
index - the index
- Returns:
the item corresponding to the given index
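embed and unembed are inverse lookups: item to index, and index back to item. The sketch below illustrates that relationship with plain collections. IndexedItemsSketch is a hypothetical class. Throwing on an unknown item and returning an empty Optional for an out-of-range index are assumptions of this sketch, not documented behavior of the library class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Hypothetical two-way item/index mapping; not part of DJL. */
final class IndexedItemsSketch {
    private final List<String> byIndex = new ArrayList<>();
    private final Map<String, Long> byItem = new HashMap<>();

    /** Assigns indices in list order; repeated items keep their first index. */
    IndexedItemsSketch(List<String> items) {
        for (String item : items) {
            if (byItem.putIfAbsent(item, (long) byIndex.size()) == null) {
                byIndex.add(item);
            }
        }
    }

    /** Returns whether the item has an index. */
    boolean hasItem(String item) {
        return byItem.containsKey(item);
    }

    /** Item to index; throws for unknown items (an assumption of this sketch). */
    long embed(String item) {
        Long index = byItem.get(item);
        if (index == null) {
            throw new IllegalArgumentException("Unknown item: " + item);
        }
        return index;
    }

    /** Index to item; empty when the index is out of range. */
    Optional<String> unembed(long index) {
        if (index < 0 || index >= byIndex.size()) {
            return Optional.empty();
        }
        return Optional.of(byIndex.get((int) index));
    }

    public static void main(String[] args) {
        IndexedItemsSketch items = new IndexedItemsSketch(Arrays.asList("cat", "dog"));
        long index = items.embed("dog");
        System.out.println(index + " -> " + items.unembed(index).orElse("?"));
    }
}
```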
-
builder
Creates a builder to build an Embedding.
- Returns:
a new builder
-
hasItem
Returns whether an item is in the embedding.
- Specified by:
hasItem in interface AbstractEmbedding<String>
- Parameters:
item - the item to test
- Returns:
true if the item is in the embedding
-