BertLossInput

case class BertLossInput(input: BertPretrainInput, maskedLanguageModelTarget: STen, wholeSentenceTarget: STen)

Input to the BertLoss module.

  • input: feature data; see the documentation of BertPretrainInput.
  • maskedLanguageModelTarget: long tensor of shape (batch size, number of masked positions), where the number of masked positions is variable. Values are the true tokens that were masked out at the positions given in input.positions.
  • wholeSentenceTarget: float tensor of shape (batch size). Values are the ground-truth targets for the whole-sentence loss, which is a binary cross-entropy with logits (BCEWithLogitsLoss). Values are floats in [0, 1].
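The shape and value contract above can be illustrated with a minimal plain-Scala sketch. It assumes nothing about lamp's actual API: STen is replaced by nested arrays, and `ToyPretrainInput`, `ToyLossInput`, `validate` and `bceWithLogits` are hypothetical names. It also assumes that `input.positions` has the same (batch size, number of masked positions) shape as `maskedLanguageModelTarget`. The loss function is a standard numerically stable binary cross-entropy on logits, the kind of loss the whole-sentence target feeds; lamp's own implementation may differ.

```scala
// Hypothetical plain-Scala stand-ins for the STen-based fields; NOT lamp's API.
// Assumption: positions has shape (batch size, number of masked positions).
final case class ToyPretrainInput(positions: Array[Array[Long]])

final case class ToyLossInput(
    input: ToyPretrainInput,
    maskedLanguageModelTarget: Array[Array[Long]], // (batch, masked positions): true token ids
    wholeSentenceTarget: Array[Double]             // (batch): targets in [0, 1]
)

object ToyLossInput {
  // Checks the documented contract: matching batch sizes, one target token
  // per masked position, and whole-sentence targets inside [0, 1].
  def validate(x: ToyLossInput): Either[String, Unit] = {
    val batch = x.input.positions.length
    if (x.maskedLanguageModelTarget.length != batch)
      Left("MLM target batch size mismatch")
    else if (x.wholeSentenceTarget.length != batch)
      Left("whole-sentence target batch size mismatch")
    else if (x.maskedLanguageModelTarget.zip(x.input.positions).exists { case (t, p) => t.length != p.length })
      Left("MLM target must have one token per masked position")
    else if (x.wholeSentenceTarget.exists(y => y < 0.0 || y > 1.0))
      Left("whole-sentence targets must lie in [0, 1]")
    else Right(())
  }

  // Numerically stable binary cross-entropy on logits, averaged over the batch:
  // loss(x, y) = max(x, 0) - x * y + log(1 + exp(-|x|))
  def bceWithLogits(logits: Array[Double], targets: Array[Double]): Double = {
    require(logits.length == targets.length && logits.nonEmpty, "need equal, non-empty inputs")
    val terms = logits.zip(targets).map { case (x, y) =>
      math.max(x, 0.0) - x * y + math.log1p(math.exp(-math.abs(x)))
    }
    terms.sum / terms.length
  }
}
```

For example, a single logit of 0 with target 1 gives a loss of ln 2 (about 0.693). The stable form avoids overflow in `exp` for large-magnitude logits.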
Companion: object

Supertypes:
trait Serializable
trait Product
trait Equals
class Object
trait Matchable
class Any

Value members

Inherited methods

def productElementNames: Iterator[String]
Inherited from: Product
def productIterator: Iterator[Any]
Inherited from: Product