Class ConstantEmbedding

All Implemented Interfaces:
Block, AbstractEmbedding, AbstractIndexedEmbedding

public class ConstantEmbedding extends AbstractBlock implements AbstractIndexedEmbedding
An AbstractIndexedEmbedding that always returns a constant value.
  • Field Details

    • embedding

      protected NDArray embedding
  • Constructor Details

    • ConstantEmbedding

      public ConstantEmbedding(NDArray embedding)
      Constructs a constant embedding with the given constant.

      The constant is assumed to be a fixed value, and the block starts out frozen. To unfreeze it, call Block.freezeParameters(boolean) with false.

      Parameters:
      embedding - the value to return for all embeddings
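The constructor's contract (store one constant, start frozen, unfreeze on request) can be modeled in plain Java. This is a minimal sketch that does not use DJL: the constant is a `double[]` rather than an `NDArray`, and the class `FrozenConstantSketch` and its `freezeParameters(boolean)` method are hypothetical stand-ins for `ConstantEmbedding` and `Block.freezeParameters(boolean)`.

```java
import java.util.Arrays;

// Hypothetical plain-Java model of ConstantEmbedding's constructor; not DJL code.
final class FrozenConstantSketch {
    private final double[] embedding; // the constant returned for every item
    private boolean frozen;           // stands in for DJL's parameter freezing

    FrozenConstantSketch(double[] embedding) {
        // Copy so later changes to the caller's array do not alter the constant.
        this.embedding = Arrays.copyOf(embedding, embedding.length);
        this.frozen = true; // starts out frozen, as the constructor docs state
    }

    // Stand-in for Block.freezeParameters(boolean): false unfreezes.
    void freezeParameters(boolean freeze) {
        this.frozen = freeze;
    }

    boolean isFrozen() {
        return frozen;
    }

    double[] value() {
        return Arrays.copyOf(embedding, embedding.length);
    }

    public static void main(String[] args) {
        FrozenConstantSketch e = new FrozenConstantSketch(new double[] {0.5, -1.0});
        System.out.println(e.isFrozen()); // true
        e.freezeParameters(false);
        System.out.println(e.isFrozen()); // false
    }
}
```

The defensive copy is a design choice of the sketch: it keeps the "fixed value" assumption true even if the caller later mutates the array it passed in.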
  • Method Details

    • forwardInternal

      protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      parameterStore - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Specified by:
      getOutputShapes in interface Block
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
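This page does not spell out the shape rule, so the following is an assumption: like a typical embedding lookup, the output shape is taken to be the first input shape with the constant's own shape appended. For a 1-D input of length n and a 1-D constant of length d, that gives (n, d), which agrees with the Shape(items.length, embeddingSize) documented for embed(NDManager, Object[]). The sketch uses plain `long[]` shapes instead of DJL's `Shape`, and `ConstantShapeSketch` is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical shape rule for a constant embedding; an assumption, not DJL's source.
final class ConstantShapeSketch {
    // Output shape = first input shape with the embedding's shape appended.
    static long[][] getOutputShapes(long[][] inputShapes, long[] embeddingShape) {
        long[] in = inputShapes[0];
        long[] out = Arrays.copyOf(in, in.length + embeddingShape.length);
        System.arraycopy(embeddingShape, 0, out, in.length, embeddingShape.length);
        return new long[][] {out};
    }

    public static void main(String[] args) {
        long[][] shapes = getOutputShapes(new long[][] {{32, 10}}, new long[] {64});
        System.out.println(Arrays.toString(shapes[0])); // [32, 10, 64]
    }
}
```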
    • saveParameters

      public void saveParameters(DataOutputStream os)
      Writes the parameters of the block to the given output stream.
      Specified by:
      saveParameters in interface Block
      Overrides:
      saveParameters in class AbstractBaseBlock
      Parameters:
      os - the output stream to save the parameters to
    • loadParameters

      public void loadParameters(NDManager manager, DataInputStream is)
      Loads the parameters from the given input stream.
      Specified by:
      loadParameters in interface Block
      Overrides:
      loadParameters in class AbstractBaseBlock
      Parameters:
      manager - an NDManager to create the parameter arrays
      is - the input stream that streams the parameter values
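saveParameters and loadParameters form a pair over DataOutputStream and DataInputStream: whatever one writes, the other must read back in the same order. This page does not say what ConstantEmbedding itself writes, so the length-prefixed format below is a hypothetical illustration of that pairing for a constant stored as a `double[]`, not DJL's on-disk format. `ConstantParamsSketch` is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

// Hypothetical persistence format for a constant vector; not DJL's on-disk format.
final class ConstantParamsSketch {
    // Writes a length prefix followed by each value.
    static void saveParameters(double[] constant, DataOutputStream os) throws IOException {
        os.writeInt(constant.length);
        for (double v : constant) {
            os.writeDouble(v);
        }
    }

    // Reads back exactly what saveParameters wrote, in the same order.
    static double[] loadParameters(DataInputStream is) throws IOException {
        int n = is.readInt();
        double[] constant = new double[n];
        for (int i = 0; i < n; i++) {
            constant[i] = is.readDouble();
        }
        return constant;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream os = new DataOutputStream(bytes);
        saveParameters(new double[] {0.25, 4.0}, os);
        os.flush();
        DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        double[] restored = loadParameters(is);
        System.out.println(restored[0] + " " + restored[1]); // 0.25 4.0
    }
}
```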
    • unembed

      public Optional<?> unembed(long index)
      Returns the item corresponding to the given index.
      Specified by:
      unembed in interface AbstractIndexedEmbedding
      Parameters:
      index - the index
      Returns:
      the item corresponding to the given index
    • encode

      public byte[] encode(Object input)
      Encodes an object of the input type into a byte array. This is used when saving and loading Embedding objects.
      Specified by:
      encode in interface AbstractIndexedEmbedding
      Parameters:
      input - the input object to be encoded
      Returns:
      the encoded byte array.
    • decode

      public Object decode(byte[] byteArray)
      Decodes the given byte array into an object of the input parameter type.
      Specified by:
      decode in interface AbstractIndexedEmbedding
      Parameters:
      byteArray - the byte array to be decoded
      Returns:
      the decoded object of the input parameter type
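encode and decode are meant to be used together, so the usual contract is that decoding an encoded item gives back an equal item. The sketch below illustrates that round trip for String items using UTF-8; it is not ConstantEmbedding's own implementation, which this page does not describe, and `StringCodecSketch` is a hypothetical name. It only accepts String inputs.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical encode/decode pair for String items; illustrates the general
// round-trip contract, not ConstantEmbedding's own implementation.
final class StringCodecSketch {
    // Only String items are supported by this sketch.
    static byte[] encode(Object input) {
        return ((String) input).getBytes(StandardCharsets.UTF_8);
    }

    static Object decode(byte[] byteArray) {
        return new String(byteArray, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        Object restored = decode(encode("caf\u00e9"));
        System.out.println(restored.equals("caf\u00e9")); // true
    }
}
```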
    • embed

      public long embed(Object item)
      Embeds an item.
      Specified by:
      embed in interface AbstractIndexedEmbedding
      Parameters:
      item - the item to embed
      Returns:
      the index of the item in the embedding
    • embed

      public NDArray embed(NDManager manager, Object[] items)
      Embeds an array of items.
      Specified by:
      embed in interface AbstractEmbedding
      Parameters:
      manager - the manager for the new embeddings
      items - the items to embed
      Returns:
      the embedding NDArray of Shape(items.length, embeddingSize)
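Because the block always returns the same constant, embedding a batch amounts to stacking one copy of the constant per item, which yields the documented (items.length, embeddingSize) result. The sketch models this with a plain `double[][]` in place of an `NDArray` created by an `NDManager`; `ConstantBatchEmbedSketch` is a hypothetical name, and the constant is assumed to be 1-D.

```java
// Hypothetical batch embedding with plain arrays; DJL's NDManager/NDArray are not used.
final class ConstantBatchEmbedSketch {
    // Returns an items.length x embeddingSize matrix whose every row is the constant.
    // The items themselves are ignored; only their count matters.
    static double[][] embed(double[] constant, Object[] items) {
        double[][] out = new double[items.length][constant.length];
        for (double[] row : out) {
            System.arraycopy(constant, 0, row, 0, constant.length);
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = embed(new double[] {0.1, 0.2}, new Object[] {"a", "b", "c"});
        System.out.println(m.length + " x " + m[0].length); // 3 x 2
    }
}
```

Each row is an independent copy, so modifying one row of the result does not change the other rows or the stored constant.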
    • hasItem

      public boolean hasItem(Object item)
      Returns whether an item is in the embedding.
      Specified by:
      hasItem in interface AbstractEmbedding
      Parameters:
      item - the item to test
      Returns:
      true if the item is in the embedding
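The indexed methods (embed(Object), unembed(long), and hasItem(Object)) take on a particular shape when every item gets the same value. The sketch below is one plausible model under stated assumptions, not documented DJL behavior: all items share a single index (0 is an arbitrary choice), unembed reports no item because a shared index cannot identify which item produced it, and hasItem accepts any item because the constant serves all inputs. `ConstantIndexSketch` and `SHARED_INDEX` are hypothetical names.

```java
import java.util.Optional;

// Hypothetical model of the indexed-embedding methods of a constant embedding.
// The shared index value, the empty reverse lookup, and hasItem always being
// true are assumptions, not documented DJL behavior.
final class ConstantIndexSketch {
    static final long SHARED_INDEX = 0L; // assumed index shared by all items

    // embed(Object): every item maps to the same index because every item
    // receives the same constant value.
    static long embed(Object item) {
        return SHARED_INDEX;
    }

    // unembed(long): a shared index cannot identify which item produced it,
    // so no item is reported.
    static Optional<?> unembed(long index) {
        return Optional.empty();
    }

    // hasItem(Object): there is no vocabulary to miss, so any item is accepted.
    static boolean hasItem(Object item) {
        return true;
    }

    public static void main(String[] args) {
        System.out.println(embed("cat") == embed(42));          // true
        System.out.println(unembed(embed("cat")).isPresent()); // false
        System.out.println(hasItem("unseen-token"));            // true
    }
}
```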