Class IdEmbedding

  • All Implemented Interfaces:
    Block

    public final class IdEmbedding
    extends AbstractBlock
    An embedding from integer IDs to float vectors. The output shape is the input shape plus one trailing dimension for the embedding. For example, if the input shape is (-1, 128) and the embedding size is 1024, the output shape is (-1, 128, 1024).
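    The following rough illustration of the lookup models the embedding table as a plain `float[][]` with one row per ID. The (D, E) row layout is an assumption. The input is modeled as an `int[][]` of IDs. The names `EmbeddingLookupSketch` and `lookup` are hypothetical, and this is not the DJL implementation. It only shows how each ID selects one table row, so an input of shape (1, 3) produces an output of shape (1, 3, 2).

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of an ID-to-vector embedding lookup (not the DJL source).
 * Assumes the table has one row per ID, i.e. shape (D, E).
 */
class EmbeddingLookupSketch {

    // Maps each ID in a (batch, seqLen) array to its table row, producing a
    // (batch, seqLen, E) array: the input shape plus one trailing dimension.
    static float[][][] lookup(float[][] table, int[][] ids) {
        int embeddingSize = table[0].length;
        float[][][] out = new float[ids.length][][];
        for (int b = 0; b < ids.length; b++) {
            out[b] = new float[ids[b].length][];
            for (int t = 0; t < ids[b].length; t++) {
                int id = ids[b][t];
                if (id < 0 || id >= table.length) {
                    throw new IllegalArgumentException("id out of range: " + id);
                }
                // Copy the row so the result does not alias the table.
                out[b][t] = Arrays.copyOf(table[id], embeddingSize);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Hypothetical table: dictionary size D = 3, embedding size E = 2.
        float[][] table = {{0f, 0f}, {1f, 2f}, {3f, 4f}};
        int[][] ids = {{2, 1, 1}}; // input shape (1, 3)
        float[][][] out = lookup(table, ids);
        // Prints the output shape as "batch x seqLen x E".
        System.out.println(out.length + " x " + out[0].length + " x " + out[0][0].length);
        System.out.println(Arrays.deepToString(out));
    }
}
```

    Running `main` prints `1 x 3 x 2` followed by `[[[3.0, 4.0], [1.0, 2.0], [1.0, 2.0]]]`. Repeated IDs simply select the same row again.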
    • Method Detail

      • getOutputShapes

        public Shape[] getOutputShapes(Shape[] inputShapes)
        Returns the expected output shapes of the block for the specified input shapes.
        Parameters:
        inputShapes - the shapes of the inputs
        Returns:
        the expected output shapes of the block
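        The next minimal sketch applies the shape rule from the class description. Shapes are modeled as `long[]` rather than DJL's `Shape` type. The helper name `outputShapes` is hypothetical. Mapping every input shape independently is also an assumption made for illustration, because the actual block may only consider its first input.

```java
import java.util.Arrays;

/** Hypothetical sketch of the output-shape rule: append the embedding size. */
class EmbeddingShapeSketch {

    // Shapes are plain long[] arrays here; -1 stands for an unknown dimension.
    static long[][] outputShapes(long[][] inputShapes, long embeddingSize) {
        long[][] result = new long[inputShapes.length][];
        for (int i = 0; i < inputShapes.length; i++) {
            long[] in = inputShapes[i];
            long[] out = Arrays.copyOf(in, in.length + 1);
            out[in.length] = embeddingSize; // new trailing embedding dimension
            result[i] = out;
        }
        return result;
    }

    public static void main(String[] args) {
        long[][] in = {{-1, 128}};
        System.out.println(Arrays.toString(outputShapes(in, 1024)[0]));
    }
}
```

        With the values from the class description, `main` prints `[-1, 128, 1024]`.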
      • probabilities

        public NDArray probabilities(ParameterStore parameterStore,
                                     NDArray input,
                                     boolean training)
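        Turns an array of embeddings of shape (d0 ... dN, E) into an array of log probabilities of shape (d0 ... dN, D), where E is the embedding size and D is the number of entries in the internal embedding table. For each embedding, the result is a distribution over the D table entries, giving the log probability that the embedding corresponds to each entry.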
        Parameters:
        parameterStore - the parameter store
        input - the embeddings to create log probabilities for
        training - true for a training forward pass
        Returns:
        log probabilities for each embedding
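        The description above does not say how the log probabilities are computed. One common approach is shown below as an assumption, not as the library's actual code. It scores an embedding against every row of a (D, E) table with a dot product and then normalizes the D scores with a log-softmax. The class `EmbeddingLogProbSketch` is hypothetical and handles a single length-E embedding. An input of shape (d0 ... dN, E) would apply it to each vector along the last axis.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch: log probabilities of one embedding (length E) over the
 * D rows of a (D, E) table, assuming dot-product scores plus log-softmax.
 */
class EmbeddingLogProbSketch {

    static double[] logProbabilities(double[][] table, double[] embedding) {
        int d = table.length;
        double[] logits = new double[d];
        double max = Double.NEGATIVE_INFINITY;
        // Raw logits: dot product of the embedding with each table row.
        for (int i = 0; i < d; i++) {
            double dot = 0;
            for (int j = 0; j < embedding.length; j++) {
                dot += table[i][j] * embedding[j];
            }
            logits[i] = dot;
            max = Math.max(max, dot);
        }
        // Numerically stable log-softmax: subtract the max before exponentiating.
        double sumExp = 0;
        for (double l : logits) {
            sumExp += Math.exp(l - max);
        }
        double logSumExp = max + Math.log(sumExp);
        double[] logProbs = new double[d];
        for (int i = 0; i < d; i++) {
            logProbs[i] = logits[i] - logSumExp;
        }
        return logProbs;
    }

    public static void main(String[] args) {
        double[][] table = {{1, 0}, {0, 1}, {1, 1}}; // D = 3, E = 2
        double[] embedding = {2, 0};
        double[] lp = logProbabilities(table, embedding);
        System.out.println(Arrays.toString(lp));
        // Exponentiated log probabilities should sum to (approximately) 1.
        double total = 0;
        for (double x : lp) {
            total += Math.exp(x);
        }
        System.out.println("sum of probabilities = " + total);
    }
}
```

        In this example, rows 0 and 2 both score 2 against the embedding, so they receive equal and highest log probabilities. Row 1 scores 0 and receives a lower value.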
      • getValue

        public NDArray getValue(ParameterStore ps,
                                Device device,
                                boolean training)
        A quick workaround that lets the BERT model access the embedding table. It should be replaced by a proper function that calculates raw logits from embeddings. TODO: replace with a function that returns logits.
        Parameters:
        ps - the parameter store
        device - device to get internal table for
        training - true for a training forward pass
        Returns:
        this embedding table as an array on the given device
      • initializeChildBlocks

        public void initializeChildBlocks(NDManager manager,
                                          DataType dataType,
                                          Shape... inputShapes)
        Description copied from class: AbstractBlock
        Initializes the child blocks of this block. Override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks based on the requested input shape for this block.
        Overrides:
        initializeChildBlocks in class AbstractBlock
        Parameters:
        manager - the manager to use for initialization
        dataType - the requested data type
        inputShapes - the expected input shapes for this block