Object

com.stripe.agate.ops

EmbeddingBag

Related Doc: package ops


object EmbeddingBag

Linear Supertypes
AnyRef, Any

Type Members

  1. sealed abstract class Mode extends AnyRef


Value Members

  1. final def !=(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  2. final def ##(): Int

    Definition Classes
    AnyRef → Any
  3. final def ==(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  4. object Mode

  5. def apply(data: Tensor[Float32.type], mode: Mode, input: Tensor[Int64.type], offsets: Option[Tensor[Int64.type]], perIndexWeights: Option[Tensor[Float32.type]]): Try[Tensor[Float32.type]]


    An implementation of PyTorch's EmbeddingBag operator. EmbeddingBag serves two primary purposes, depending on your use case:

    • EmbeddingBag optimizes a group of N categorical embeddings whose results are summed or averaged: instead of training N separate embeddings, a single large embedding is trained for all of the categoricals and the reduction is performed in place. This greatly reduces training time.
    • EmbeddingBag allows a dynamically sized number of lookups into an embedding space, with all of the results reduced together. For example, an arbitrarily long tokenized sentence can be looked up to produce a mean embedding, as is done by something like FastText.

    EmbeddingBag accepts two types of input: either a 2D matrix of indices, or a 1D vector of indices along with a 1D vector of "offsets". In either case, the goal is to take a chunk (a vector) of indices, look each index up in an embedding, and then reduce the resulting embeddings (sum, mean, max, etc.) to a single vector per chunk, which becomes one row of the output.

    If offsets is defined, then it must be a 1D vector of offsets and input must be a 1D vector of long indices. The first offset must be 0 and no offset can be larger than the length of input. The "chunks" that are then reduced are defined by each pair of adjacent offsets, which give the start (inclusive) and end (exclusive) of a range of indices in input. The chunk that starts at the last offset extends to the end of input. For example, if offsets is [0, 2, 3, 6] and input is [0, 1, 5, 3, 9, 2, 1, 2], then the chunks that must be reduced are [ [0, 1], [5], [3, 9, 2], [1, 2] ]. Each chunk is dynamically sized, which is why this type of input can't be provided as a 2D tensor instead.
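
    The chunking rule above can be sketched with plain Scala collections. This is only an illustration, not the library's implementation: OffsetChunksSketch and chunks are hypothetical names, Vector[Long] stands in for the Int64 tensors, and the check that offsets are non-decreasing is an assumption that the description above does not state.

```scala
// Illustrative sketch only: uses plain Vectors instead of agate Tensor types.
object OffsetChunksSketch {
  // Hypothetical helper, not part of com.stripe.agate.ops.
  def chunks(input: Vector[Long], offsets: Vector[Long]): Vector[Vector[Long]] = {
    require(offsets.nonEmpty && offsets.head == 0L, "first offset must be 0")
    require(offsets.forall(o => o >= 0L && o <= input.length), "offset out of range")
    // Assumption: offsets are non-decreasing, so every range is well formed.
    require(offsets.zip(offsets.tail).forall { case (a, b) => a <= b },
      "offsets must be non-decreasing")
    // Each chunk ends where the next one starts; the last chunk ends at input.length.
    val ends = offsets.tail :+ input.length.toLong
    offsets.zip(ends).map { case (start, end) => input.slice(start.toInt, end.toInt) }
  }
}
```

    With offsets [0, 2, 3, 6] and input [0, 1, 5, 3, 9, 2, 1, 2], chunks returns the four chunks listed in the example above.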

    If offsets is not defined, then input must be a 2D matrix of long indices. Each row in the matrix is a chunk that must be reduced.
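
    For the 2D case, the lookup-then-reduce step might look like the following sketch. It uses plain Scala collections rather than the agate Tensor types, and the Reduction type with its Sum, Mean and Max cases is a hypothetical stand-in defined here for illustration; it is not the library's Mode class, whose members are not listed on this page. Empty rows are rejected purely to keep the sketch short.

```scala
// Illustrative sketch only: plain Vectors, hypothetical names throughout.
object BagReduceSketch {
  // Stand-in for an aggregation mode; NOT com.stripe.agate.ops.EmbeddingBag.Mode.
  sealed trait Reduction
  case object Sum extends Reduction
  case object Mean extends Reduction
  case object Max extends Reduction

  // Element-wise combination of two embeddings of equal length.
  private def zipWith(a: Vector[Float], b: Vector[Float])(f: (Float, Float) => Float): Vector[Float] =
    a.zip(b).map { case (x, y) => f(x, y) }

  // data: one embedding per row. input: one chunk of indices per row.
  def reduceRows(data: Vector[Vector[Float]],
                 input: Vector[Vector[Long]],
                 how: Reduction): Vector[Vector[Float]] =
    input.map { row =>
      require(row.nonEmpty, "this sketch does not handle empty rows")
      val looked = row.map(i => data(i.toInt)) // look up each index in data
      how match {
        case Sum  => looked.reduce(zipWith(_, _)(_ + _))
        case Mean => looked.reduce(zipWith(_, _)(_ + _)).map(_ / row.length)
        case Max  => looked.reduce(zipWith(_, _)((x, y) => math.max(x, y)))
      }
    }
}
```

    The result has one row per row of input and one column per column of data.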

    If perIndexWeights is defined, then it must have the same shape as input and mode must be Sum. The perIndexWeights are used to scale the embeddings returned by each index we look up in data, which is why each weight must correspond to exactly 1 index in input.
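
    The effect of per-index weights on a single chunk can be sketched as a weighted sum. As before, this is a plain-Scala illustration under assumptions, not the library's code: WeightedSumSketch and weightedSum are hypothetical names, and a non-empty chunk is assumed.

```scala
// Illustrative sketch only: one chunk, plain Vectors, hypothetical names.
object WeightedSumSketch {
  def weightedSum(data: Vector[Vector[Float]],
                  indices: Vector[Long],
                  weights: Vector[Float]): Vector[Float] = {
    // Each weight must pair with exactly one index.
    require(indices.length == weights.length, "weights must match indices one-to-one")
    require(indices.nonEmpty, "this sketch does not handle empty chunks")
    indices.zip(weights)
      .map { case (i, w) => data(i.toInt).map(_ * w) }          // scale each looked-up embedding
      .reduce((a, b) => a.zip(b).map { case (x, y) => x + y })  // then sum element-wise
  }
}
```

    For example, with embeddings [1, 2] at index 0 and [5, 6] at index 2, indices [0, 2] and weights [2.0, 0.5] give 2 * [1, 2] + 0.5 * [5, 6] = [4.5, 7.0].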

    data

    a 2-D matrix of floats, each row is an embedding

    mode

    the type of aggregation to perform on the embeddings

    input

    a rank-1 (if offsets is defined) or rank-2 (otherwise) tensor of long indices

    offsets

    an optional rank-1 tensor of long offsets into input

    perIndexWeights

    an optional set of weights to apply to each lookup from input

    returns

    a 2-D matrix of floats, the aggregated embeddings for the batch

  6. final def asInstanceOf[T0]: T0

    Definition Classes
    Any
  7. def clone(): AnyRef

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  8. final def eq(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  9. def equals(arg0: Any): Boolean

    Definition Classes
    AnyRef → Any
  10. def finalize(): Unit

    Attributes
    protected[java.lang]
    Definition Classes
    AnyRef
    Annotations
    @throws( classOf[java.lang.Throwable] )
  11. final def getClass(): Class[_]

    Definition Classes
    AnyRef → Any
  12. def hashCode(): Int

    Definition Classes
    AnyRef → Any
  13. final def isInstanceOf[T0]: Boolean

    Definition Classes
    Any
  14. final def ne(arg0: AnyRef): Boolean

    Definition Classes
    AnyRef
  15. final def notify(): Unit

    Definition Classes
    AnyRef
  16. final def notifyAll(): Unit

    Definition Classes
    AnyRef
  17. final def synchronized[T0](arg0: ⇒ T0): T0

    Definition Classes
    AnyRef
  18. def toString(): String

    Definition Classes
    AnyRef → Any
  19. final def wait(): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  20. final def wait(arg0: Long, arg1: Int): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
  21. final def wait(arg0: Long): Unit

    Definition Classes
    AnyRef
    Annotations
    @throws( ... )
