public class TrainableWordEmbedding extends Embedding<java.lang.String> implements WordEmbedding
TrainableWordEmbedding is an implementation of WordEmbedding and Embedding based on a SimpleVocabulary. This WordEmbedding is ideal when there are no pre-trained embeddings available.
Modifier and Type | Class and Description |
---|---|
static class | TrainableWordEmbedding.Builder: A builder for a TrainableWordEmbedding. |
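Nested classes/interfaces inherited from class Embedding: Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>>, Embedding.DefaultEmbedding, Embedding.DefaultItem
Fields inherited from class Embedding: embedding, embeddingSize, fallthroughEmbedding, numEmbeddings, sparseFormat
Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, version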
Constructor and Description |
---|
TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items): Constructs a pretrained embedding. |
TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items, SparseFormat sparseFormat): Constructs a pretrained embedding. |
TrainableWordEmbedding(TrainableWordEmbedding.Builder builder): Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder. |
TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize): Constructs a new instance of TrainableWordEmbedding from a SimpleVocabulary and a given embedding size. |
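To illustrate what constructing an embedding from a vocabulary and an embedding size amounts to, the following is a minimal sketch in plain Java. It does not use the library: ToyEmbeddingTable, its seed parameter, and the small random initialization are assumptions made for illustration only, and the real class stores its weights in an NDArray rather than a float[][].

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical stand-in for a trainable lookup table; not the library implementation. */
final class ToyEmbeddingTable {
    private final Map<String, Integer> indexOf = new HashMap<>();
    private final int embeddingSize;
    private final float[][] weights; // numEmbeddings rows, embeddingSize columns

    ToyEmbeddingTable(List<String> vocabulary, int embeddingSize, long seed) {
        if (embeddingSize <= 0) {
            throw new IllegalArgumentException("embeddingSize must be positive");
        }
        List<String> tokens = new ArrayList<>(vocabulary);
        for (int i = 0; i < tokens.size(); i++) {
            indexOf.putIfAbsent(tokens.get(i), i); // first occurrence of a token wins
        }
        this.embeddingSize = embeddingSize;
        this.weights = new float[tokens.size()][embeddingSize];
        Random random = new Random(seed);
        for (float[] row : weights) {
            for (int j = 0; j < embeddingSize; j++) {
                // Assumed initialization: small random values that training would adjust.
                row[j] = (random.nextFloat() - 0.5f) * 0.1f;
            }
        }
    }

    int numEmbeddings() {
        return weights.length;
    }

    int embeddingSize() {
        return embeddingSize;
    }

    /** Returns the current vector for a token; a training loop would update it in place. */
    float[] row(String token) {
        Integer index = indexOf.get(token);
        if (index == null) {
            throw new IllegalArgumentException("unknown token: " + token);
        }
        return weights[index];
    }
}
```

In this sketch the table has one row per vocabulary token and embeddingSize columns, which is the shape a trainable embedding needs when no pre-trained vectors exist.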
Modifier and Type | Method and Description |
---|---|
static TrainableWordEmbedding.Builder | builder(): Creates a builder to build an Embedding. |
java.lang.String | decode(byte[] byteArray): Decodes the given byte array into an object of input parameter type. |
long | embed(java.lang.String item): Embeds an item. |
NDArray | embedWord(NDArray index): Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String). |
byte[] | encode(java.lang.String input): Encodes an object of input type into a byte array. |
boolean | hasItem(java.lang.String item): Returns whether an item is in the embedding. |
long | preprocessWordToEmbed(java.lang.String word): Pre-processes the word to embed into an array to pass into the model. |
java.util.Optional<java.lang.String> | unembed(long index): Returns the item corresponding to the given index. |
java.lang.String | unembedWord(NDArray word): Returns the closest matching word for the given word embedding. |
boolean | vocabularyContains(java.lang.String word): Returns whether an embedding exists for a word. |
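The word-level methods in the table above form a round trip: a word is turned into an index, the index is turned into a vector, and a vector can be mapped back to the closest word. The sketch below models that flow in plain Java under stated assumptions. ToyWordLookup is hypothetical, it uses float[] in place of NDArray, it throws on unknown words, and it uses squared Euclidean distance for "closest"; none of these choices are confirmed by this page.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical model of the word -> index -> vector -> word flow; not the library code. */
final class ToyWordLookup {
    private final List<String> words;
    private final float[][] vectors; // vectors[i] belongs to words.get(i)

    ToyWordLookup(List<String> words, float[][] vectors) {
        if (words.isEmpty() || words.size() != vectors.length) {
            throw new IllegalArgumentException("need one vector per word");
        }
        this.words = new ArrayList<>(words);
        this.vectors = vectors;
    }

    boolean vocabularyContains(String word) {
        return words.contains(word);
    }

    /** Step 1: map the word to its row index. Unknown words are rejected in this sketch. */
    long preprocessWordToEmbed(String word) {
        int index = words.indexOf(word);
        if (index < 0) {
            throw new IllegalArgumentException("unknown word: " + word);
        }
        return index;
    }

    /** Step 2: look up the vector stored at that index. */
    float[] embedWord(long index) {
        return vectors[(int) index].clone();
    }

    /** Reverse direction: return the word whose vector is nearest to the query. */
    String unembedWord(float[] query) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int i = 0; i < vectors.length; i++) {
            double distance = 0;
            for (int j = 0; j < query.length; j++) {
                double diff = vectors[i][j] - query[j];
                distance += diff * diff;
            }
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return words.get(best);
    }
}
```

Splitting the lookup into an index step and a vector step mirrors the summary above, where preprocessWordToEmbed returns a long and embedWord consumes the preprocessed value.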
Methods inherited from class Embedding: embed, embedding, forwardInternal, getOutputShapes, loadParameters, prepare, saveParameters
Methods inherited from class AbstractBlock: addChildBlock, addParameter, beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getChildren, getDirectParameters, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface WordEmbedding: embedWord, embedWord
Methods inherited from interface Block: forward, validateLayout
public TrainableWordEmbedding(TrainableWordEmbedding.Builder builder)
Constructs a new instance of TrainableWordEmbedding from the TrainableWordEmbedding.Builder.
Parameters:
builder - the TrainableWordEmbedding.Builder
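The builder-based constructor follows the usual builder pattern: options are collected on a builder object, validated, and then handed to a private or dedicated constructor. The sketch below shows that pattern in plain Java. ToyEmbeddingConfig and its methods (setVocabulary, setEmbeddingSize, optSparseGradient) are hypothetical names chosen for illustration; this page does not list the methods of TrainableWordEmbedding.Builder, so they should not be read as the library's API.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical configuration object built through a builder; not the library class. */
final class ToyEmbeddingConfig {
    final List<String> vocabulary;
    final int embeddingSize;
    final boolean sparseGradient;

    // The constructor takes the builder, as the constructor documented above does.
    ToyEmbeddingConfig(Builder builder) {
        this.vocabulary = new ArrayList<>(builder.vocabulary);
        this.embeddingSize = builder.embeddingSize;
        this.sparseGradient = builder.sparseGradient;
    }

    static Builder builder() {
        return new Builder();
    }

    static final class Builder {
        private List<String> vocabulary;
        private int embeddingSize = -1;
        private boolean sparseGradient; // optional, defaults to false

        Builder setVocabulary(List<String> vocabulary) {
            this.vocabulary = vocabulary;
            return this;
        }

        Builder setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        Builder optSparseGradient(boolean sparseGradient) {
            this.sparseGradient = sparseGradient;
            return this;
        }

        /** Validates the required settings before constructing the object. */
        ToyEmbeddingConfig build() {
            if (vocabulary == null) {
                throw new IllegalStateException("vocabulary is required");
            }
            if (embeddingSize <= 0) {
                throw new IllegalStateException("embeddingSize must be positive");
            }
            return new ToyEmbeddingConfig(this);
        }
    }
}
```

A call such as ToyEmbeddingConfig.builder().setVocabulary(words).setEmbeddingSize(8).build() then yields a fully validated object, which is the main reason to prefer a builder over a long constructor argument list.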
public TrainableWordEmbedding(Vocabulary vocabulary, int embeddingSize)
Constructs a new instance of TrainableWordEmbedding from a SimpleVocabulary and a given embedding size.
Parameters:
vocabulary - a Vocabulary to get tokens from
embeddingSize - the required embedding size

public TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)

public TrainableWordEmbedding(NDArray embedding, java.util.List<java.lang.String> items, SparseFormat sparseFormat)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
items - the items in the embedding (in matching order to the embedding array)
sparseFormat - whether to compute row sparse gradient in the backward calculation

public boolean vocabularyContains(java.lang.String word)
Returns whether an embedding exists for a word.
Specified by:
vocabularyContains in interface WordEmbedding
Parameters:
word - the word to check

public long preprocessWordToEmbed(java.lang.String word)
Pre-processes the word to embed into an array to pass into the model. Make sure to call WordEmbedding.embedWord(NDManager, long) after this.
Specified by:
preprocessWordToEmbed in interface WordEmbedding
Parameters:
word - the word to embed

public NDArray embedWord(NDArray index) throws EmbeddingException
Embeds the word after it has been preprocessed using WordEmbedding.preprocessWordToEmbed(String).
Specified by:
embedWord in interface WordEmbedding
Parameters:
index - the index of the word to embed
Throws:
EmbeddingException - if there is an error while trying to embed

public java.lang.String unembedWord(NDArray word)
Returns the closest matching word for the given word embedding.
Specified by:
unembedWord in interface WordEmbedding
Parameters:
word - the word embedding to find the matching string word for

public byte[] encode(java.lang.String input)
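Encodes an object of input type into a byte array. This is used in saving and loading the Embedding objects.
Specified by:
encode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
input - the input object to be encoded

public java.lang.String decode(byte[] byteArray)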
Decodes the given byte array into an object of input parameter type.
Specified by:
decode in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
byteArray - the byte array to be decoded

public long embed(java.lang.String item)
Embeds an item.
Specified by:
embed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
item - the item to embed

public java.util.Optional<java.lang.String> unembed(long index)
Returns the item corresponding to the given index.
Specified by:
unembed in interface AbstractIndexedEmbedding<java.lang.String>
Parameters:
index - the index

public static TrainableWordEmbedding.Builder builder()
Creates a builder to build an Embedding.

public boolean hasItem(java.lang.String item)
Returns whether an item is in the embedding.
Specified by:
hasItem in interface AbstractEmbedding<java.lang.String>
Parameters:
item - the item to test
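The item-level methods (encode, decode, hasItem, embed, unembed) can be pictured as operations over an ordered list of strings. The following plain-Java sketch shows one way they might fit together. ToyIndexedItems is hypothetical, and two behaviors are assumptions not stated on this page: strings are serialized as UTF-8, and embed returns -1 for a missing item.

```java
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

/** Hypothetical model of an indexed String embedding's item methods; not the library code. */
final class ToyIndexedItems {
    private final List<String> items;

    ToyIndexedItems(List<String> items) {
        this.items = new ArrayList<>(items);
    }

    /** Serializes an item to bytes. UTF-8 is an assumption of this sketch. */
    byte[] encode(String input) {
        return input.getBytes(StandardCharsets.UTF_8);
    }

    /** Inverse of encode: rebuilds the item from its bytes. */
    String decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    boolean hasItem(String item) {
        return items.contains(item);
    }

    /** Maps an item to its index; this sketch returns -1 when the item is absent. */
    long embed(String item) {
        return items.indexOf(item);
    }

    /** Maps an index back to its item, or an empty Optional when out of range. */
    Optional<String> unembed(long index) {
        if (index < 0 || index >= items.size()) {
            return Optional.empty();
        }
        return Optional.of(items.get((int) index));
    }
}
```

Returning an Optional from unembed, as the signature above does, lets callers handle an out-of-range index without a null check or an exception.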