Package ai.djl.modality.nlp.embedding
Class TrainableWordEmbedding
java.lang.Object
  ai.djl.nn.AbstractBaseBlock
    ai.djl.nn.AbstractBlock
      ai.djl.nn.core.Embedding<java.lang.String>
        ai.djl.modality.nlp.embedding.TrainableWordEmbedding
All Implemented Interfaces:
WordEmbedding, Block, AbstractEmbedding<java.lang.String>, AbstractIndexedEmbedding<java.lang.String>
public class TrainableWordEmbedding extends Embedding<java.lang.String> implements WordEmbedding
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a DefaultVocabulary. This WordEmbedding is ideal when there are no pre-trained embeddings available.
Nested Class Summary
Nested Classes
static class TrainableWordEmbedding.Builder
A builder for a TrainableWordEmbedding.
Nested classes/interfaces inherited from class ai.djl.nn.core.Embedding
Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>>, Embedding.DefaultEmbedding, Embedding.DefaultItem
Field Summary
Fields inherited from class ai.djl.nn.core.Embedding
embedding, embeddingSize, fallthroughEmbedding, numEmbeddings, sparseFormat
Fields inherited from class ai.djl.nn.AbstractBlock
children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock
inputNames, inputShapes, version
Constructor Summary
Constructors
TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
Constructs a new instance of TrainableWordEmbedding from a DefaultVocabulary and a given embedding size.
TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items)
Constructs a pretrained embedding.
TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items, SparseFormat sparseFormat)
Constructs a pretrained embedding.
Method Summary
static TrainableWordEmbedding.Builder builder()
Creates a builder to build an Embedding.
java.lang.String decode(byte[] byteArray)
Decodes the given byte array into an object of input parameter type.
long embed(java.lang.String item)
Embeds an item.
NDArray embedWord(NDArray index)
Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
byte[] encode(java.lang.String input)
Encodes an object of input type into a byte array.
boolean hasItem(java.lang.String item)
Returns whether an item is in the embedding.
long preprocessWordToEmbed(java.lang.String word)
Pre-processes the word to embed into an array to pass into the model.
java.util.Optional<java.lang.String> unembed(long index)
Returns the item corresponding to the given index.
java.lang.String unembedWord(NDArray word)
Returns the closest matching word for the given index.
boolean vocabularyContains(java.lang.String word)
Returns whether an embedding exists for a word.
Methods inherited from class ai.djl.nn.core.Embedding
embed, embedding, forwardInternal, getOutputShapes, loadParameters, prepare, saveParameters
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters
Methods inherited from interface ai.djl.modality.nlp.embedding.WordEmbedding
embedWord, embedWord
Constructor Detail
TrainableWordEmbedding
public TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
Parameters:
builder - the TrainableWordEmbedding.Builder
TrainableWordEmbedding
public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
Constructs a new instance of TrainableWordEmbedding from a DefaultVocabulary and a given embedding size.
Parameters:
vocabulary - a Vocabulary to get tokens from
embeddingSize - the required embedding size
TrainableWordEmbedding
public TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
TrainableWordEmbedding
public TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items, SparseFormat sparseFormat)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
sparseFormat - whether to compute row sparse gradient in the backward calculation
Method Detail
vocabularyContains
public boolean vocabularyContains(java.lang.String word)
Returns whether an embedding exists for a word.
Specified by:
vocabularyContains in interface WordEmbedding
Parameters:
word - the word to check
Returns:
true if an embedding exists
preprocessWordToEmbed
public long preprocessWordToEmbed(java.lang.String word)
Pre-processes the word to embed into an array to pass into the model.
Make sure to call WordEmbedding.embedWord(NDManager, long) after this.
Specified by:
preprocessWordToEmbed in interface WordEmbedding
Parameters:
word - the word to embed
Returns:
the word that is ready to embed
embedWord
public NDArray embedWord(NDArray index) throws EmbeddingException
Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
Specified by:
embedWord in interface WordEmbedding
Parameters:
index - the index of the word to embed
Returns:
the embedded word
Throws:
EmbeddingException - if there is an error while trying to embed
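The two-step flow described above (map a word to an index, then look that index up in the embedding table) can be sketched without DJL. The class below is a hypothetical stand-in, not the library's implementation: `ToyWordEmbedding`, its fields, the plain `float[][]` table, and the choice to send unknown words to index 0 are all assumptions made for illustration. The real class works with NDArray and a Vocabulary.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for illustration only; this is not DJL code. */
final class ToyWordEmbedding {
    private final Map<String, Long> wordToIndex = new HashMap<>();
    private final float[][] table; // one row of embeddingSize floats per word

    ToyWordEmbedding(List<String> words, int embeddingSize) {
        for (int i = 0; i < words.size(); i++) {
            wordToIndex.put(words.get(i), (long) i);
        }
        table = new float[words.size()][embeddingSize];
        // A trainable embedding would start from initial values and update them
        // during training; here each row is filled with its own index so the
        // lookup result is easy to recognise.
        for (int i = 0; i < table.length; i++) {
            Arrays.fill(table[i], i);
        }
    }

    /** Step 1: turn a word into an index (assumption: unknown words map to index 0). */
    long preprocessWordToEmbed(String word) {
        return wordToIndex.getOrDefault(word, 0L);
    }

    /** Step 2: return a copy of the table row for a previously computed index. */
    float[] embedWord(long index) {
        return table[(int) index].clone();
    }

    public static void main(String[] args) {
        ToyWordEmbedding emb = new ToyWordEmbedding(List.of("<unk>", "hello", "world"), 4);
        long index = emb.preprocessWordToEmbed("world");
        float[] vector = emb.embedWord(index);
        System.out.println(index + " -> " + Arrays.toString(vector));
    }
}
```

Splitting the lookup into two steps keeps the string handling separate from the numeric lookup, which is presumably why the interface asks callers to preprocess first and embed afterwards.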
unembedWord
public java.lang.String unembedWord(NDArray word)
Returns the closest matching word for the given index.
Specified by:
unembedWord in interface WordEmbedding
Parameters:
word - the word embedding to find the matching string word for.
Returns:
a word similar to the passed in embedding
encode
public byte[] encode(java.lang.String input)
Encodes an object of input type into a byte array. This is used in saving and loading the Embedding objects.
Specified by:
encode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
input - the input object to be encoded
Returns:
the encoded byte array.
decode
public java.lang.String decode(byte[] byteArray)
Decodes the given byte array into an object of input parameter type.
Specified by:
decode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
byteArray - the byte array to be decoded
Returns:
the decoded object of input parameter type
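Since encode and decode are described as inverse operations used when saving and loading, a round trip through both should return the original string. The sketch below shows that property with a hypothetical helper, `StringItemCodec`. The use of UTF-8 is an assumption for this example; this page does not state which charset the class uses.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/** Hypothetical helper for illustration only; the UTF-8 charset is an assumption. */
final class StringItemCodec {
    /** Turns a string item into bytes suitable for writing to a stream. */
    static byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /** Rebuilds the string item from bytes produced by encode. */
    static String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String word = "caf\u00e9"; // includes one non-ASCII character
        byte[] bytes = encode(word);
        System.out.println(Arrays.toString(bytes));
        System.out.println(decode(bytes).equals(word));
    }
}
```

Naming the charset explicitly on both sides matters here: relying on the platform default could make a file saved on one machine decode differently on another.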
embed
public long embed(java.lang.String item)
Embeds an item.
Specified by:
embed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
item - the item to embed
Returns:
the index of the item in the embedding
unembed
public java.util.Optional<java.lang.String> unembed(long index)
Returns the item corresponding to the given index.
Specified by:
unembed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
index - the index
Returns:
the item corresponding to the given index
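The embed and unembed signatures suggest a two-way mapping between items and indices, with Optional covering indices that have no item. A minimal sketch of that contract follows. `ToyIndexedItems` is a hypothetical class, and returning -1 for an unknown item is an assumption of this example only; this page does not say how the real class treats items outside the vocabulary.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/** Hypothetical two-way item/index mapping; this is not DJL code. */
final class ToyIndexedItems {
    private final List<String> items;
    private final Map<String, Long> indices = new HashMap<>();

    ToyIndexedItems(List<String> items) {
        this.items = new ArrayList<>(items);
        for (int i = 0; i < this.items.size(); i++) {
            indices.put(this.items.get(i), (long) i);
        }
    }

    /** Item to index (assumption: -1 marks an item that is not present). */
    long embed(String item) {
        return indices.getOrDefault(item, -1L);
    }

    /** Index to item; empty when the index is out of range. */
    Optional<String> unembed(long index) {
        if (index < 0 || index >= items.size()) {
            return Optional.empty();
        }
        return Optional.of(items.get((int) index));
    }

    public static void main(String[] args) {
        ToyIndexedItems vocab = new ToyIndexedItems(List.of("apple", "banana"));
        long index = vocab.embed("banana");
        System.out.println(index + " <-> " + vocab.unembed(index).orElse("(none)"));
        System.out.println(vocab.unembed(99).isPresent());
    }
}
```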
builder
public static TrainableWordEmbedding.Builder builder()
Creates a builder to build an Embedding.
Returns:
a new builder
hasItem
public boolean hasItem(java.lang.String item)
Returns whether an item is in the embedding.
Specified by:
hasItem in interface AbstractEmbedding<java.lang.String>
Parameters:
item - the item to test
Returns:
true if the item is in the embedding