Package ai.djl.nn.core
Class Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>>
java.lang.Object
    ai.djl.nn.core.Embedding.BaseBuilder<T,B>
Type Parameters:
T - the type of object to embed
B - the concrete builder type returned by the fluent setters and by self()
- Direct Known Subclasses:
TrainableWordEmbedding.Builder
public abstract static class Embedding.BaseBuilder<T,B extends Embedding.BaseBuilder<T,B>> extends java.lang.Object
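The recursive bound B extends Embedding.BaseBuilder<T,B>, together with the abstract self() method, lets every setter return the concrete subclass type. This means a chained call can end with an option that only the subclass defines. The sketch below is a simplified, hypothetical model of that pattern in plain Java. SketchBaseBuilder, SketchWordBuilder, optUnknownToken and describe are illustrative names, not part of DJL.

```java
// Hypothetical, simplified model of the self-typed builder pattern used by
// Embedding.BaseBuilder<T, B extends Embedding.BaseBuilder<T, B>>. Not DJL code.
abstract class SketchBaseBuilder<T, B extends SketchBaseBuilder<T, B>> {
    protected int embeddingSize;
    protected int numEmbeddings;

    // Setters return B, the concrete builder type, so subclass-only
    // methods remain callable later in the same chain.
    public B setEmbeddingSize(int embeddingSize) {
        this.embeddingSize = embeddingSize;
        return self();
    }

    public B optNumEmbeddings(int numEmbeddings) {
        this.numEmbeddings = numEmbeddings;
        return self();
    }

    // Each concrete builder implements this as "return this;".
    protected abstract B self();
}

// Hypothetical concrete builder that adds one option of its own.
final class SketchWordBuilder extends SketchBaseBuilder<String, SketchWordBuilder> {
    private String unknownToken = "<unk>";

    public SketchWordBuilder optUnknownToken(String unknownToken) {
        this.unknownToken = unknownToken;
        return this;
    }

    @Override
    protected SketchWordBuilder self() {
        return this;
    }

    // Summarizes the configured values for inspection.
    public String describe() {
        return numEmbeddings + "x" + embeddingSize + " unk=" + unknownToken;
    }
}
```

If the setters returned the base type instead of B, the call to optUnknownToken after optNumEmbeddings would not compile without a cast.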
Field Summary
Modifier and Type                        Field
protected T                              defaultItem
protected int                            embeddingSize
protected java.lang.Class<T>             embeddingType
protected AbstractIndexedEmbedding<T>    fallthrough
protected int                            numEmbeddings
protected SparseFormat                   sparseFormat
protected boolean                        useDefault

Constructor Summary
Modifier     Constructor
protected    BaseBuilder()

Method Summary
Modifier and Type        Method and Description
java.lang.Class<T>       getEmbeddingType()
                         Returns the embedded type.
B                        optDefaultItem(T defaultItem)
                         Sets the item whose embedding is used for undefined items.
B                        optFallthrough(AbstractIndexedEmbedding<T> fallthrough)
                         Sets a custom handler for items not found in the embedding.
B                        optNumEmbeddings(int numEmbeddings)
                         Sets the size of the dictionary of embeddings.
B                        optSparseFormat(SparseFormat sparseFormat)
                         Sets the optional parameter that controls whether to compute a row-sparse gradient in the backward calculation.
B                        optUseDefault(boolean useDefault)
                         Sets whether to use a default embedding for undefined items (default true).
protected abstract B     self()
                         Returns this Builder object.
B                        setEmbeddingSize(int embeddingSize)
                         Sets the size of the embeddings.
protected abstract B     setType(java.lang.Class<T> embeddingType)
                         Creates a new Embedding.BaseBuilder with the specified embedding type.

Field Detail

embeddingType
protected java.lang.Class<T> embeddingType

numEmbeddings
protected int numEmbeddings

embeddingSize
protected int embeddingSize

useDefault
protected boolean useDefault

defaultItem
protected T defaultItem

fallthrough
protected AbstractIndexedEmbedding<T> fallthrough

sparseFormat
protected SparseFormat sparseFormat

Method Detail

getEmbeddingType
public java.lang.Class<T> getEmbeddingType()
Returns the embedded type.
Returns:
the embedded type

setType
protected abstract B setType(java.lang.Class<T> embeddingType)
Creates a new Embedding.BaseBuilder with the specified embedding type.
Parameters:
embeddingType - the embedding class
Returns:
a new Embedding.BaseBuilder with the specified embedding type
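setType and getEmbeddingType record and expose a java.lang.Class<T>, and a Class object is the usual way to check or cast values against T, because generic type arguments are erased at run time. The following sketch assumes that this is the purpose of the stored type; it uses a hypothetical TypedVocabulary class (not a DJL type) to show a type token rejecting items of the wrong class.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: a Class<T> type token lets code check and cast values
// at run time even though T itself is erased. Assumption: this illustrates a
// plausible reason to record the embedding type; it is not DJL's code.
final class TypedVocabulary<T> {
    private final Class<T> embeddingType;
    private final Map<T, Integer> index = new HashMap<>();

    TypedVocabulary(Class<T> embeddingType) {
        this.embeddingType = embeddingType;
    }

    Class<T> getEmbeddingType() {
        return embeddingType;
    }

    // Accepts an untyped value, verifies it against the type token, and
    // returns its index, assigning the next free index to new items.
    int add(Object item) {
        if (!embeddingType.isInstance(item)) {
            String actual = item == null ? "null" : item.getClass().getSimpleName();
            throw new IllegalArgumentException(
                    "Expected " + embeddingType.getSimpleName() + " but got " + actual);
        }
        T typed = embeddingType.cast(item);
        Integer existing = index.get(typed);
        if (existing != null) {
            return existing;
        }
        int id = index.size();
        index.put(typed, id);
        return id;
    }
}
```

Calling add(42) on a TypedVocabulary<String> throws IllegalArgumentException, a check that could not be written against the erased T alone.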

setEmbeddingSize
public B setEmbeddingSize(int embeddingSize)
Sets the size of the embeddings.
Parameters:
embeddingSize - the size of the 1D embedding array
Returns:
this Builder

optNumEmbeddings
public B optNumEmbeddings(int numEmbeddings)
Sets the size of the dictionary of embeddings.
Parameters:
numEmbeddings - the size of the dictionary of embeddings
Returns:
this Builder
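Taken together, optNumEmbeddings and setEmbeddingSize describe a lookup table with one row per dictionary entry and one column per embedding dimension. Here is a minimal sketch, assuming that conventional dense layout (DJL's actual parameter storage is not shown), with EmbeddingTable as a hypothetical class.

```java
import java.util.Arrays;
import java.util.Random;

// Minimal sketch of a dense embedding table, assuming the conventional layout:
// numEmbeddings rows (dictionary size) x embeddingSize columns (vector length).
// EmbeddingTable is a hypothetical class, not part of DJL.
final class EmbeddingTable {
    private final int numEmbeddings;
    private final int embeddingSize;
    private final float[][] weights;

    EmbeddingTable(int numEmbeddings, int embeddingSize, long seed) {
        this.numEmbeddings = numEmbeddings;
        this.embeddingSize = embeddingSize;
        this.weights = new float[numEmbeddings][embeddingSize];
        Random rng = new Random(seed);
        // Small random initial values; real initializers vary.
        for (float[] row : weights) {
            for (int j = 0; j < embeddingSize; j++) {
                row[j] = (float) (rng.nextGaussian() * 0.01);
            }
        }
    }

    // Returns a copy of the 1D embedding stored for a dictionary index.
    float[] lookup(int index) {
        if (index < 0 || index >= numEmbeddings) {
            throw new IndexOutOfBoundsException(
                    "index " + index + " outside [0, " + numEmbeddings + ")");
        }
        return Arrays.copyOf(weights[index], embeddingSize);
    }

    // Number of trainable weights: one value per (row, column) pair.
    long parameterCount() {
        return (long) numEmbeddings * embeddingSize;
    }
}
```

For example, a dictionary of 10,000 items with 300-dimensional embeddings holds 10,000 × 300 = 3,000,000 weights.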

optUseDefault
public B optUseDefault(boolean useDefault)
Sets whether to use a default embedding for undefined items (default true).
Parameters:
useDefault - true to provide a default embedding; false to throw an IllegalArgumentException when the item cannot be found
Returns:
this Builder
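The two optUseDefault modes can be illustrated with a hypothetical index that maps items to table rows (DefaultingIndex is not a DJL class). For illustration only, it assumes that row 0 holds the shared default embedding. The real layout may differ.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical index illustrating the two optUseDefault modes. Assumption for
// illustration only: row 0 holds the shared default embedding and known items
// start at row 1. DJL's actual row layout may differ.
final class DefaultingIndex {
    static final int DEFAULT_ROW = 0;
    private final Map<String, Integer> rows = new HashMap<>();
    private final boolean useDefault;

    DefaultingIndex(List<String> vocabulary, boolean useDefault) {
        this.useDefault = useDefault;
        for (String item : vocabulary) {
            rows.putIfAbsent(item, rows.size() + 1);
        }
    }

    int rowOf(String item) {
        Integer row = rows.get(item);
        if (row != null) {
            return row;
        }
        if (useDefault) {
            return DEFAULT_ROW; // useDefault == true: shared default embedding
        }
        // useDefault == false: unknown items are an error
        throw new IllegalArgumentException("Item not found in embedding: " + item);
    }
}
```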

optDefaultItem
public B optDefaultItem(T defaultItem)
Sets the item whose embedding is used for undefined items.
Parameters:
defaultItem - the item to use as a default
Returns:
this Builder
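With optDefaultItem, an unknown item is treated as if it were the chosen default item, so it shares that item's embedding. The hypothetical sketch below (DefaultItemIndex is an illustrative name) assumes the default item must itself be in the vocabulary.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of the optDefaultItem idea: an unknown item is looked up
// as if it were the configured default item, so both share one embedding row.
// Assumption: the default item must itself appear in the vocabulary.
final class DefaultItemIndex<T> {
    private final Map<T, Integer> rows = new HashMap<>();
    private final int defaultRow;

    DefaultItemIndex(List<T> vocabulary, T defaultItem) {
        for (T item : vocabulary) {
            rows.putIfAbsent(item, rows.size());
        }
        Integer row = rows.get(defaultItem);
        if (row == null) {
            throw new IllegalArgumentException("Default item not in vocabulary: " + defaultItem);
        }
        this.defaultRow = row;
    }

    // Known items map to their own row; anything else maps to the default item's row.
    int rowOf(T item) {
        return rows.getOrDefault(item, defaultRow);
    }
}
```

With a vocabulary of "<unk>", "the" and "cat" and "<unk>" as the default, an unseen word such as "zebra" resolves to the same row as "<unk>".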

optFallthrough
public B optFallthrough(AbstractIndexedEmbedding<T> fallthrough)
Sets a custom handler for items not found in the embedding. See the standard fallthrough handlers optUseDefault(boolean) and optDefaultItem(Object).
Parameters:
fallthrough - the embedding to handle default cases
Returns:
this Builder
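optFallthrough generalizes optUseDefault and optDefaultItem into a pluggable strategy for unknown items. In DJL the handler is an AbstractIndexedEmbedding<T>. To keep the sketch below self-contained, it substitutes a plain java.util.function.ToIntFunction<T>. Its hashBuckets handler, which hashes unknown items into extra rows, is a hypothetical example of a custom policy, not a DJL feature.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToIntFunction;

// Hypothetical sketch: a fallthrough is modeled as a function from an unknown
// item to a row index. In DJL the handler is an AbstractIndexedEmbedding<T>;
// a plain ToIntFunction<T> stands in for it here.
final class FallthroughIndex<T> {
    private final Map<T, Integer> rows = new HashMap<>();
    private final ToIntFunction<T> fallthrough;

    FallthroughIndex(List<T> vocabulary, ToIntFunction<T> fallthrough) {
        for (T item : vocabulary) {
            rows.putIfAbsent(item, rows.size());
        }
        this.fallthrough = fallthrough;
    }

    // Known items use their own row; unknown items are delegated to the handler.
    int rowOf(T item) {
        Integer row = rows.get(item);
        return row != null ? row : fallthrough.applyAsInt(item);
    }

    // Example custom policy: hash unknown items into `buckets` extra rows
    // placed after the known vocabulary, so different unknowns can get
    // different but stable embeddings.
    static <T> ToIntFunction<T> hashBuckets(int vocabSize, int buckets) {
        return item -> vocabSize + Math.floorMod(item.hashCode(), buckets);
    }
}
```

In this model, a handler such as item -> 0 behaves like a shared default row, and one that throws IllegalArgumentException behaves like the strict mode.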

optSparseFormat
public B optSparseFormat(SparseFormat sparseFormat)
Sets the optional parameter that controls whether to compute a row-sparse gradient in the backward calculation. If a row-sparse format is set, the gradient's storage type is row_sparse.
Parameters:
sparseFormat - the format that determines whether a row-sparse gradient is computed in the backward calculation
Returns:
this Builder
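A row-sparse gradient is worthwhile because, in an embedding lookup, only the rows whose indices appear in the batch receive a nonzero gradient, and repeated indices accumulate into the same row. The plain-Java sketch below computes the same gradient in two forms: densely, and as a map of touched rows. SparseEmbeddingGrad is a hypothetical name, not DJL's SparseFormat machinery.

```java
import java.util.Map;
import java.util.TreeMap;

// Minimal sketch of why an embedding's weight gradient is naturally row-sparse:
// only rows whose indices appear in the batch receive a nonzero gradient.
// Plain Java for illustration; this is not DJL's SparseFormat machinery.
final class SparseEmbeddingGrad {

    // Dense form: a full numEmbeddings x embeddingSize matrix, mostly zeros.
    // upstream[i] is the gradient flowing into the lookup of indices[i].
    static float[][] dense(int numEmbeddings, int embeddingSize, int[] indices, float[][] upstream) {
        float[][] grad = new float[numEmbeddings][embeddingSize];
        for (int i = 0; i < indices.length; i++) {
            for (int j = 0; j < embeddingSize; j++) {
                grad[indices[i]][j] += upstream[i][j];
            }
        }
        return grad;
    }

    // Row-sparse form: only the touched rows, keyed by row index.
    // Repeated indices accumulate into the same row, matching the dense form.
    static Map<Integer, float[]> rowSparse(int embeddingSize, int[] indices, float[][] upstream) {
        Map<Integer, float[]> grad = new TreeMap<>();
        for (int i = 0; i < indices.length; i++) {
            float[] row = grad.computeIfAbsent(indices[i], k -> new float[embeddingSize]);
            for (int j = 0; j < embeddingSize; j++) {
                row[j] += upstream[i][j];
            }
        }
        return grad;
    }
}
```

With a dictionary of 10,000 rows and a batch that touches 3 distinct rows, the dense form stores 10,000 rows while the row-sparse form stores 3.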

self
protected abstract B self()
Returns this Builder object.
Returns:
this BaseBuilder