public abstract class Embedding<T> extends AbstractBlock implements AbstractIndexedEmbedding<T>
An Embedding block maps a collection of items to 1-dimensional representative NDArrays.
Type Parameters:
T - the type of item that should be embedded and mapped to the array
Modifier and Type | Class and Description |
---|---|
static class | Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>> |
protected class | Embedding.DefaultEmbedding |
protected class | Embedding.DefaultItem |
Modifier and Type | Field and Description |
---|---|
protected DataType | dataType |
protected Parameter | embedding |
protected int | embeddingSize |
protected AbstractIndexedEmbedding<T> | fallthroughEmbedding |
protected int | numItems |
protected boolean | sparseGrad |

Fields inherited from class AbstractBlock: children, inputNames, inputShapes, parameters, parameterShapeCallbacks, version
Modifier | Constructor and Description |
---|---|
protected | Embedding(Embedding.BaseBuilder<T,?> baseBuilder) |
public | Embedding(NDArray embedding): Constructs a pretrained embedding. |
public | Embedding(NDArray embedding, boolean sparseGrad): Constructs a pretrained embedding. |
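
The sparseGrad flag on the second constructor asks for a row sparse gradient in the backward calculation. As a rough illustration of what that means, the sketch below accumulates gradients only for the table rows that were actually looked up, so untouched rows are never stored. It is plain Java with no DJL dependency; the class name RowSparseGrad, the method accumulate, and the map-based representation are all hypothetical and are not how DJL or its engines store gradients.

```java
import java.util.Map;
import java.util.TreeMap;

/** Conceptual sketch only: row-sparse gradient of an embedding table, not DJL's backward pass. */
final class RowSparseGrad {
    private RowSparseGrad() {}

    /**
     * indices[i] is the table row looked up for position i, and upstream[i] is the
     * gradient flowing back into that position's output vector.
     * Returns a map holding gradient rows only for the rows that were used.
     */
    static Map<Integer, float[]> accumulate(int[] indices, float[][] upstream) {
        Map<Integer, float[]> grad = new TreeMap<>();
        for (int i = 0; i < indices.length; i++) {
            final float[] g = upstream[i];
            float[] row = grad.computeIfAbsent(indices[i], k -> new float[g.length]);
            for (int j = 0; j < g.length; j++) {
                row[j] += g[j]; // repeated indices sum their gradients
            }
        }
        return grad;
    }
}
```

With a large numItems and a small batch, a representation like this touches only a handful of rows per step, which is the usual motivation for a sparse gradient on an embedding table.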
Modifier and Type | Method and Description |
---|---|
NDArray | embed(NDManager manager, T[] items): Embeds an array of items. |
NDList | forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params): Applies the operating function of the block once. |
Shape[] | getOutputShapes(NDManager manager, Shape[] inputShapes): Returns the expected output shapes of the block for the specified input shapes. |
void | loadParameters(NDManager manager, java.io.DataInputStream is): Loads the parameters from the given input stream. |
void | saveParameters(java.io.DataOutputStream os): Writes the parameters of the block to the given outputStream. |
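
Methods inherited from class AbstractBlock: addChildBlock, addParameter, addParameter, addParameter, beforeInitialize, cast, clear, describeInput, getChildren, getDirectParameters, getParameters, getParameterShape, initialize, initializeChildBlocks, isInitialized, loadMetadata, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, toString
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface AbstractIndexedEmbedding: decode, embed, encode, unembed
Methods inherited from interface AbstractEmbedding: hasItem
Methods inherited from interface Block: forward, forward, validateLayout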
protected int embeddingSize
protected boolean sparseGrad
protected DataType dataType
protected int numItems
protected AbstractIndexedEmbedding<T> fallthroughEmbedding
protected Parameter embedding
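
To make the relationship between these fields concrete, here is a minimal plain-Java sketch of an embedding as a lookup table with numItems rows of embeddingSize values each. It is a conceptual stand-in, not DJL's implementation: the class ToyEmbedding, its constructor, and the use of a single fallthrough row in place of a full fallthroughEmbedding are assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/** Conceptual sketch only: a plain-Java lookup table, not DJL's Embedding. */
final class ToyEmbedding<T> {
    private final Map<T, Integer> index = new HashMap<>(); // item -> row number
    private final float[][] table;                         // numItems x embeddingSize
    private final int fallthroughRow;                      // hypothetical stand-in for fallthroughEmbedding

    ToyEmbedding(T[] items, float[][] table, int fallthroughRow) {
        if (items.length != table.length) {
            throw new IllegalArgumentException("one table row is required per item");
        }
        if (fallthroughRow < 0 || fallthroughRow >= table.length) {
            throw new IllegalArgumentException("fallthroughRow must be a valid row");
        }
        for (int i = 0; i < items.length; i++) {
            index.put(items[i], i);
        }
        this.table = table;
        this.fallthroughRow = fallthroughRow;
    }

    int numItems() {
        return table.length;
    }

    int embeddingSize() {
        return table[0].length;
    }

    boolean hasItem(T item) {
        return index.containsKey(item);
    }

    /** Returns a copy of the row for the item, or the fallthrough row for unknown items. */
    float[] embed(T item) {
        int row = index.getOrDefault(item, fallthroughRow);
        return Arrays.copyOf(table[row], table[row].length);
    }
}
```

For example, a ToyEmbedding<String> built from the items "<unk>", "cat", "dog" with fallthroughRow 0 returns the "<unk>" row when asked to embed a word it has never seen.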
protected Embedding(Embedding.BaseBuilder<T,?> baseBuilder)
public Embedding(NDArray embedding)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
public Embedding(NDArray embedding, boolean sparseGrad)
Constructs a pretrained embedding.
Parameters:
embedding - the embedding array
sparseGrad - whether to compute row sparse gradient in the backward calculation
public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes)
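Returns the expected output shapes of the block for the specified input shapes.
Specified by:
getOutputShapes in interface Block
Parameters:
manager - an NDManager
inputShapes - the shapes of the inputs
public NDList forward(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String,java.lang.Object> params)
Applies the operating function of the block once.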
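
As a sketch of the shape arithmetic behind getOutputShapes, assume an embedding lookup replaces each input index with a vector of length embeddingSize, so the output shape is the input shape with one extra trailing dimension. The helper below uses long[] in place of DJL's Shape, and the names EmbeddingShapes and outputShape are hypothetical; check the class source before relying on this rule.

```java
import java.util.Arrays;

/** Conceptual sketch only: shape arithmetic for an embedding lookup, with long[] standing in for Shape. */
final class EmbeddingShapes {
    private EmbeddingShapes() {}

    /** Appends embeddingSize as a new trailing dimension to the index shape. */
    static long[] outputShape(long[] indexShape, int embeddingSize) {
        long[] out = Arrays.copyOf(indexShape, indexShape.length + 1);
        out[indexShape.length] = embeddingSize;
        return out;
    }
}
```

Under this assumption, a batch of 32 sequences of 10 indices each, embedded with embeddingSize 64, yields an output of shape (32, 10, 64).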
public void saveParameters(java.io.DataOutputStream os) throws java.io.IOException
Writes the parameters of the block to the given output stream.
Specified by:
saveParameters in interface Block
Overrides:
saveParameters in class AbstractBlock
Parameters:
os - the output stream to save the parameters to
Throws:
java.io.IOException - if an I/O error occurs
public void loadParameters(NDManager manager, java.io.DataInputStream is) throws java.io.IOException, MalformedModelException
Loads the parameters from the given input stream.
Specified by:
loadParameters in interface Block
Overrides:
loadParameters in class AbstractBlock
Parameters:
manager - an NDManager to create the parameter arrays
is - the input stream that streams the parameter values
Throws:
java.io.IOException - if an I/O error occurs
MalformedModelException - if the model file is corrupted or unsupported
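
The save and load pair can be pictured as a versioned round trip over DataOutputStream and DataInputStream. The sketch below writes a version byte, the table dimensions, and the values, then reads them back and rejects an unknown version. The binary layout, the class ToyTableIo, and its methods are hypothetical and are not DJL's parameter file format; where DJL throws MalformedModelException for a corrupted or unsupported file, this sketch uses IOException so that it needs nothing beyond java.io.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Conceptual sketch only: a hypothetical binary layout, not DJL's parameter file format. */
final class ToyTableIo {
    private static final byte VERSION = 1; // hypothetical format version

    private ToyTableIo() {}

    /** Writes version, numItems, embeddingSize, then the values row by row. */
    static void save(float[][] table, DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(table.length);
        os.writeInt(table.length == 0 ? 0 : table[0].length);
        for (float[] row : table) {
            for (float value : row) {
                os.writeFloat(value);
            }
        }
    }

    /** Reads a table written by save, rejecting unknown versions and bad headers. */
    static float[][] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            // DJL reports this case with MalformedModelException.
            throw new IOException("Unsupported version: " + version);
        }
        int numItems = is.readInt();
        int embeddingSize = is.readInt();
        if (numItems < 0 || embeddingSize < 0) {
            throw new IOException("Corrupted header");
        }
        float[][] table = new float[numItems][embeddingSize];
        for (float[] row : table) {
            for (int j = 0; j < embeddingSize; j++) {
                row[j] = is.readFloat();
            }
        }
        return table;
    }

    /** Saves to an in-memory buffer and loads it back, for demonstration. */
    static float[][] roundTrip(float[][] table) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            save(table, new DataOutputStream(buffer));
            return load(new DataInputStream(new ByteArrayInputStream(buffer.toByteArray())));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Writing the version first lets the reader fail fast on files it does not understand, before it allocates any arrays based on header values.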