An implementation of PyTorch's EmbeddingBag operator.
EmbeddingBag accepts two forms of input: either a 2D matrix of indices, or a 1D vector of indices together with a 1D vector of "offsets." In either case, the goal is to take a chunk/vector of indices, look each index up in an embedding matrix, and then reduce these embeddings (sum, mean, max, etc.) to a single vector per chunk.
If offsets is defined, then it must be a 1D vector of offsets and input must be a 1D vector of long indices. The first offset must be 0, and no offset may be larger than the length of input. Each pair of adjacent offsets defines one chunk to be reduced: the start (inclusive) and end (exclusive) of a range of indices in input. The chunk for the last offset implicitly extends to the end of input. For example, if offsets is [0, 2, 3, 6] and input is [0, 1, 5, 3, 9, 2, 1, 2], then the chunks that must be reduced are [ [0, 1], [5], [3, 9, 2], [1, 2] ]. Each chunk is dynamically sized, which is why this form of input can't be provided as a 2D tensor instead.
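The chunking rule above can be sketched in plain Python (a hypothetical helper for illustration, not part of the operator's API):

```python
def chunks_from_offsets(indices, offsets):
    # Adjacent offsets give [start, end) ranges; the chunk for the last
    # offset runs to the end of the indices vector.
    ends = list(offsets[1:]) + [len(indices)]
    return [indices[start:end] for start, end in zip(offsets, ends)]

print(chunks_from_offsets([0, 1, 5, 3, 9, 2, 1, 2], [0, 2, 3, 6]))
# [[0, 1], [5], [3, 9, 2], [1, 2]]
```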
If offsets is not defined, then input must be a 2D matrix of long indices. Each row in the matrix is a chunk that must be reduced.
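For the 2D form, each row is looked up and reduced independently. A minimal sum-mode sketch with a toy embedding matrix (all names are hypothetical):

```python
embedding = [[1.0, 0.0],
             [0.0, 1.0],
             [2.0, 2.0]]          # 3 embeddings of dimension 2
input_2d = [[0, 1], [1, 2]]       # each row is one chunk

output = []
for row in input_2d:
    vecs = [embedding[i] for i in row]           # per-index lookups
    output.append([sum(c) for c in zip(*vecs)])  # sum reduction
print(output)
# [[1.0, 1.0], [2.0, 3.0]]
```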
If perIndexWeights is defined, then it must have the same shape as input, and mode must be Sum. The perIndexWeights are used to scale the embeddings returned by each index we look up in data, which is why each weight must correspond to exactly 1 index in input.
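Per-index weighting in Sum mode simply scales each looked-up embedding before the reduction. A hypothetical sketch (toy data, not the real tensor implementation):

```python
embedding = [[1.0, 2.0],
             [3.0, 4.0]]
input_1d = [0, 1, 1]
offsets = [0, 2]              # chunks: input_1d[0:2] and input_1d[2:]
weights = [0.5, 1.0, 2.0]     # exactly one weight per index in input_1d

output = []
ends = offsets[1:] + [len(input_1d)]
for start, end in zip(offsets, ends):
    # Scale each looked-up embedding by its weight, then sum the chunk.
    bag = [[w * x for x in embedding[i]]
           for i, w in zip(input_1d[start:end], weights[start:end])]
    output.append([sum(c) for c in zip(*bag)])
print(output)
# [[3.5, 5.0], [6.0, 8.0]]
```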
- data: a 2-D matrix of floats; each row is an embedding
- mode: the type of aggregation to perform on the embeddings
- input: a rank-2 tensor, or a rank-1 tensor if offsets is defined, of long indices
- offsets: an optional rank-1 tensor of long offsets into input
- perIndexWeights: an optional set of weights to apply to each embedding looked up from data
- returns: a 2-D matrix of floats, the aggregated embeddings for the batch
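Putting the pieces together, the whole operator can be sketched in pure Python. This is an illustrative, unoptimized reference under the rules described above (names are hypothetical; the real operator works on tensors):

```python
def embedding_bag(data, mode, input, offsets=None, per_index_weights=None):
    # Split input into chunks: by offsets (1D form) or by rows (2D form).
    if offsets is not None:
        ends = list(offsets[1:]) + [len(input)]
        spans = list(zip(offsets, ends))
        chunks = [input[s:e] for s, e in spans]
        weight_chunks = ([per_index_weights[s:e] for s, e in spans]
                         if per_index_weights is not None else None)
    else:
        chunks = input
        weight_chunks = per_index_weights
    output = []
    for bag, chunk in enumerate(chunks):
        vecs = [data[i] for i in chunk]  # embedding lookups
        if weight_chunks is not None:
            assert mode == "sum"  # per-index weights require Sum mode
            vecs = [[w * x for x in v]
                    for w, v in zip(weight_chunks[bag], vecs)]
        if mode == "sum":
            output.append([sum(c) for c in zip(*vecs)])
        elif mode == "mean":
            output.append([sum(c) / len(vecs) for c in zip(*vecs)])
        elif mode == "max":
            output.append([max(c) for c in zip(*vecs)])
    return output
```

For example, with data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], the 1D form embedding_bag(data, "sum", [0, 1, 2, 2], offsets=[0, 2]) reduces the chunks [0, 1] and [2, 2], while the 2D form embedding_bag(data, "mean", [[0, 2]]) averages one row of lookups.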