Package ai.djl.modality.nlp.embedding
Class TrainableTextEmbedding
java.lang.Object
    ai.djl.nn.AbstractBaseBlock
        ai.djl.nn.AbstractBlock
            ai.djl.modality.nlp.embedding.TrainableTextEmbedding
All Implemented Interfaces:
TextEmbedding, Block
public class TrainableTextEmbedding extends AbstractBlock implements TextEmbedding
TrainableTextEmbedding is an implementation of TextEmbedding based on the TrainableWordEmbedding block. This TextEmbedding is ideal when no pre-trained embeddings are available, or when a pre-trained embedding needs to be trained further.
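What makes the embedding "trainable" is that its lookup table is a parameter that training updates. The plain-Java sketch below is not DJL code: DJL computes gradients with its autograd engine and applies them through an optimizer, whereas here the gradients are simply passed in. TrainableTableSketch and sgdStep are hypothetical names; starting from existing weights mirrors the case of further training a pre-trained embedding.

```java
import java.util.Arrays;

/**
 * Plain-Java sketch, not DJL code: a trainable embedding is a table with one
 * row per vocabulary index, and a training step updates only the rows used.
 */
final class TrainableTableSketch {
    private final float[][] weights;

    /** Starts from given weights, e.g. a pre-trained table or random values. */
    TrainableTableSketch(float[][] initialWeights) {
        weights = new float[initialWeights.length][];
        for (int i = 0; i < initialWeights.length; i++) {
            weights[i] = initialWeights[i].clone(); // defensive copy of each row
        }
    }

    /** Returns a copy of the row for one vocabulary index. */
    float[] row(int index) {
        return weights[index].clone();
    }

    /**
     * One plain SGD step: for each looked-up index, subtract learningRate * grad
     * from that row. An index that appears twice receives both updates.
     */
    void sgdStep(long[] indices, float[][] grads, float learningRate) {
        for (int i = 0; i < indices.length; i++) {
            float[] target = weights[(int) indices[i]];
            for (int j = 0; j < target.length; j++) {
                target[j] -= learningRate * grads[i][j];
            }
        }
    }

    public static void main(String[] args) {
        float[][] pretrained = {{0.5f, 0.5f}, {0.5f, 0.5f}, {0.5f, 0.5f}};
        TrainableTableSketch table = new TrainableTableSketch(pretrained);
        table.sgdStep(new long[] {0, 2}, new float[][] {{1f, 1f}, {2f, -2f}}, 0.1f);
        System.out.println(Arrays.toString(table.row(2)));
    }
}
```

Rows that do not appear in the batch (row 1 in the usage above) keep their values, which is why sparse updates are the norm for embedding tables.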
Field Summary
Fields inherited from class ai.djl.nn.AbstractBlock: children, parameters
Fields inherited from class ai.djl.nn.AbstractBaseBlock: inputNames, inputShapes, version
Constructor Summary
- TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding)
  Constructs a TrainableTextEmbedding.
Method Summary (modifier and type, method, description)
- NDArray embedText(NDArray textIndices)
  Embeds text that has already been preprocessed using TextEmbedding.preprocessTextToEmbed(List).
- protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
  A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
- Shape[] getOutputShapes(Shape[] inputShapes)
  Returns the expected output shapes of the block for the specified input shapes.
- void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
  Initializes the child blocks of this block.
- long[] preprocessTextToEmbed(java.util.List<java.lang.String> text)
  Preprocesses the text to embed into an array to pass into the model.
- java.util.List<java.lang.String> unembedText(NDArray textEmbedding)
  Returns the closest matching text for a given embedding.
Methods inherited from class ai.djl.nn.AbstractBlock
addChildBlock, addChildBlock, addChildBlockSingleton, addParameter, getChildren, getDirectParameters
Methods inherited from class ai.djl.nn.AbstractBaseBlock
beforeInitialize, cast, clear, describeInput, forward, forward, forwardInternal, getInputShapes, getParameters, initialize, isInitialized, loadMetadata, loadParameters, prepare, readInputShapes, saveInputShapes, saveMetadata, saveParameters, setInitializer, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface ai.djl.nn.Block
forward, freezeParameters
Methods inherited from interface ai.djl.modality.nlp.embedding.TextEmbedding
embedText, embedText
Constructor Detail
TrainableTextEmbedding
public TrainableTextEmbedding(TrainableWordEmbedding wordEmbedding)
Constructs a TrainableTextEmbedding.
Parameters:
wordEmbedding - the word embedding used to embed each word
Method Detail
preprocessTextToEmbed
public long[] preprocessTextToEmbed(java.util.List<java.lang.String> text)
Preprocesses the text to embed into an array to pass into the model. Make sure to call TextEmbedding.embedText(NDManager, long[]) after this.
Specified by:
preprocessTextToEmbed in interface TextEmbedding
Parameters:
text - the text to embed
Returns:
the indices of the text, ready to embed
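Preprocessing turns each token into an integer index in the embedding's vocabulary. The plain-Java sketch below is only an illustration, not DJL's Vocabulary: the first-seen index order and the fallback of out-of-vocabulary tokens to an "<unk>" index are assumptions, and TokenIndexSketch is a hypothetical name.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Plain-Java sketch, not DJL's Vocabulary: map each token to an index. */
final class TokenIndexSketch {
    private final Map<String, Long> indices = new HashMap<>();
    private final long unknownIndex;

    /** Assigns indices in first-seen order; the unknown token is appended if absent. */
    TokenIndexSketch(List<String> vocabulary, String unknownToken) {
        for (String token : vocabulary) {
            indices.putIfAbsent(token, (long) indices.size());
        }
        indices.putIfAbsent(unknownToken, (long) indices.size());
        unknownIndex = indices.get(unknownToken);
    }

    /** One index per token; out-of-vocabulary tokens map to the unknown index. */
    long[] preprocessTextToEmbed(List<String> text) {
        long[] result = new long[text.size()];
        for (int i = 0; i < text.size(); i++) {
            result[i] = indices.getOrDefault(text.get(i), unknownIndex);
        }
        return result;
    }

    public static void main(String[] args) {
        TokenIndexSketch vocab = new TokenIndexSketch(List.of("the", "cat", "sat"), "<unk>");
        // "the" -> 0, "sat" -> 2, and the unseen "dog" falls back to "<unk>" -> 3
        System.out.println(Arrays.toString(vocab.preprocessTextToEmbed(List.of("the", "dog", "sat"))));
    }
}
```

The resulting long[] is the kind of value this page says to pass on to an embedText overload.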
embedText
public NDArray embedText(NDArray textIndices) throws EmbeddingException
Embeds text that has already been preprocessed using TextEmbedding.preprocessTextToEmbed(List).
Specified by:
embedText in interface TextEmbedding
Parameters:
textIndices - the indices of the text to embed
Returns:
the embedded text
Throws:
EmbeddingException - if there is an error while trying to embed
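Conceptually, embedding preprocessed indices is a row lookup: n indices into a table of d-dimensional rows give an n x d result. The plain-Java sketch below shows only that idea and says nothing about how DJL implements embedText; EmbedTextSketch is a hypothetical name, and IllegalArgumentException stands in for EmbeddingException.

```java
import java.util.Arrays;

/** Plain-Java sketch, not DJL code: embedding indices is a row lookup. */
final class EmbedTextSketch {
    /**
     * Returns one row of table per index, so n indices give an n x d result.
     * IllegalArgumentException stands in for DJL's EmbeddingException here.
     */
    static float[][] embedText(float[][] table, long[] textIndices) {
        float[][] out = new float[textIndices.length][];
        for (int i = 0; i < textIndices.length; i++) {
            long idx = textIndices[i];
            if (idx < 0 || idx >= table.length) {
                throw new IllegalArgumentException("index out of range: " + idx);
            }
            out[i] = table[(int) idx].clone(); // copy so callers cannot alter the table
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] table = {{1f, 0f}, {0f, 1f}, {1f, 1f}};
        System.out.println(Arrays.deepToString(embedText(table, new long[] {2, 0})));
    }
}
```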
unembedText
public java.util.List<java.lang.String> unembedText(NDArray textEmbedding)
Returns the closest matching text for a given embedding.
Specified by:
unembedText in interface TextEmbedding
Parameters:
textEmbedding - the text embedding to find the matching string text for
Returns:
text similar to the passed-in embedding
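This page does not define "closest matching". One common reading is a nearest-neighbor search over the embedding table, sketched below in plain Java under the assumption that "closest" means the smallest squared Euclidean distance. DJL may use a different criterion; UnembedTextSketch and its parameters are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

/** Plain-Java sketch, not DJL code: map each vector back to its nearest table row. */
final class UnembedTextSketch {
    /**
     * For each row of textEmbedding, returns the vocabulary entry whose table row
     * has the smallest squared Euclidean distance (ties keep the lower index).
     */
    static List<String> unembedText(float[][] table, List<String> vocabulary, float[][] textEmbedding) {
        if (table.length == 0 || table.length != vocabulary.size()) {
            throw new IllegalArgumentException("need one vocabulary entry per table row");
        }
        List<String> result = new ArrayList<>(textEmbedding.length);
        for (float[] query : textEmbedding) {
            int best = 0;
            double bestDistance = Double.POSITIVE_INFINITY;
            for (int r = 0; r < table.length; r++) {
                double distance = 0;
                for (int j = 0; j < query.length; j++) {
                    double diff = table[r][j] - query[j];
                    distance += diff * diff;
                }
                if (distance < bestDistance) {
                    bestDistance = distance;
                    best = r;
                }
            }
            result.add(vocabulary.get(best));
        }
        return result;
    }

    public static void main(String[] args) {
        float[][] table = {{1f, 0f}, {0f, 1f}, {1f, 1f}};
        float[][] noisy = {{0.9f, 0.1f}, {0.2f, 0.9f}};
        // Each noisy vector maps back to the nearest row: prints [a, b]
        System.out.println(unembedText(table, List.of("a", "b", "c"), noisy));
    }
}
```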
forwardInternal
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
A helper for Block.forward(ParameterStore, NDList, boolean, PairList) after initialization.
Specified by:
forwardInternal in class AbstractBaseBlock
Parameters:
parameterStore - the parameter store
inputs - the input NDList
training - true for a training forward pass
params - optional parameters
Returns:
the output of the forward pass
initializeChildBlocks
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
Initializes the child blocks of this block. You need to override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks based on the requested input shapes for this block.
Overrides:
initializeChildBlocks in class AbstractBaseBlock
Parameters:
manager - the manager to use for initialization
dataType - the requested data type
inputShapes - the expected input shapes for this block
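Both forwardInternal and initializeChildBlocks describe a block that owns a child: the TrainableWordEmbedding passed to the constructor. Assuming the class simply passes both calls through to that child (this page does not say so outright), the pattern looks like the plain-Java sketch below. SketchBlock, SketchWordEmbedding, and SketchTextEmbedding are hypothetical stand-ins, not DJL's Block, TrainableWordEmbedding, or TrainableTextEmbedding.

```java
import java.util.Arrays;

/** Hypothetical minimal block contract for this sketch; not DJL's Block interface. */
interface SketchBlock {
    void initialize(long[] inputShape);

    float[][] forward(long[] indices);
}

/** Hypothetical stand-in for TrainableWordEmbedding: row i is filled with the value i. */
final class SketchWordEmbedding implements SketchBlock {
    private final int vocabSize;
    private final int embeddingSize;
    private float[][] table; // null until initialize() is called
    private long[] initializedWith;

    SketchWordEmbedding(int vocabSize, int embeddingSize) {
        this.vocabSize = vocabSize;
        this.embeddingSize = embeddingSize;
    }

    @Override
    public void initialize(long[] inputShape) {
        // The table's shape depends on the vocabulary, not on the input shape.
        table = new float[vocabSize][embeddingSize];
        for (int i = 0; i < vocabSize; i++) {
            Arrays.fill(table[i], (float) i);
        }
        initializedWith = inputShape.clone(); // recorded so the delegation is observable
    }

    @Override
    public float[][] forward(long[] indices) {
        if (table == null) {
            throw new IllegalStateException("initialize() must be called first");
        }
        float[][] out = new float[indices.length][];
        for (int i = 0; i < indices.length; i++) {
            out[i] = table[(int) indices[i]].clone();
        }
        return out;
    }

    long[] initializedWith() {
        return initializedWith;
    }
}

/** Wrapper that owns one child block and delegates both calls to it unchanged. */
final class SketchTextEmbedding implements SketchBlock {
    private final SketchBlock child;

    SketchTextEmbedding(SketchBlock child) {
        this.child = child;
    }

    @Override
    public void initialize(long[] inputShape) {
        child.initialize(inputShape); // the wrapper's input shapes go straight to the child
    }

    @Override
    public float[][] forward(long[] indices) {
        return child.forward(indices); // no extra computation in the wrapper
    }

    public static void main(String[] args) {
        SketchTextEmbedding text = new SketchTextEmbedding(new SketchWordEmbedding(4, 3));
        text.initialize(new long[] {2});
        System.out.println(Arrays.deepToString(text.forward(new long[] {2, 0})));
    }
}
```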
getOutputShapes
public Shape[] getOutputShapes(Shape[] inputShapes)
Returns the expected output shapes of the block for the specified input shapes.
Specified by:
getOutputShapes in interface Block
Parameters:
inputShapes - the shapes of the inputs
Returns:
the expected output shapes of the block
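For an embedding lookup, each integer index in the input becomes a vector, so the usual shape rule is to append the embedding size to the input shape: (batch, seqLen) becomes (batch, seqLen, embeddingSize). The sketch below assumes TrainableTextEmbedding reports the output shapes of its wrapped word embedding and that this rule applies; EmbeddingShapeSketch is a hypothetical name, and long[] stands in for DJL's Shape.

```java
import java.util.Arrays;

/** Plain-Java sketch of the usual embedding shape rule; not DJL's Shape class. */
final class EmbeddingShapeSketch {
    /** Appends the embedding size: (batch, seqLen) becomes (batch, seqLen, embeddingSize). */
    static long[] outputShape(long[] inputShape, long embeddingSize) {
        long[] out = Arrays.copyOf(inputShape, inputShape.length + 1);
        out[inputShape.length] = embeddingSize;
        return out;
    }

    public static void main(String[] args) {
        // A batch of 4 sequences of 7 token indices with 50-dimensional embeddings
        System.out.println(Arrays.toString(outputShape(new long[] {4, 7}, 50)));
    }
}
```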