public class BertPretrainingLoss extends AbstractCompositeLoss
Fields inherited from class AbstractCompositeLoss: components
Fields inherited from class Evaluator: totalInstances
Constructor | Description
---|---
BertPretrainingLoss() | Creates a loss that combines the next sentence prediction loss and the masked language model loss for BERT pretraining.
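The sketch below shows the arithmetic behind a combined pretraining loss. It assumes that the composite adds its two component losses and that each component is a mean cross-entropy: the next sentence loss over two classes per sentence pair, and the masked language model loss averaged over masked token positions only. BertLossSketch and all of its methods are hypothetical. They operate on plain Java arrays instead of NDList and are not DJL's BertNextSentenceLoss or BertMaskedLanguageModelLoss, whose exact reductions this page does not specify.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch (not DJL's implementation).
 * Assumption: total pretraining loss = NSP loss + MLM loss, each a mean cross-entropy.
 */
final class BertLossSketch {

    /** Numerically stable log-softmax of one row of logits. */
    static double[] logSoftmax(double[] logits) {
        double max = Arrays.stream(logits).max().orElse(0.0);
        double sumExp = 0.0;
        for (double z : logits) {
            sumExp += Math.exp(z - max);
        }
        double logSum = max + Math.log(sumExp);
        double[] out = new double[logits.length];
        for (int i = 0; i < logits.length; i++) {
            out[i] = logits[i] - logSum;
        }
        return out;
    }

    /** Next sentence loss: mean cross-entropy over sentence pairs, two logits per pair. */
    static double nextSentenceLoss(double[][] logits, int[] labels) {
        double total = 0.0;
        for (int i = 0; i < labels.length; i++) {
            total -= logSoftmax(logits[i])[labels[i]];
        }
        return total / labels.length;
    }

    /** Masked LM loss: cross-entropy averaged over positions where mask[i] is true. */
    static double maskedLanguageModelLoss(double[][] logits, int[] labels, boolean[] mask) {
        double total = 0.0;
        int count = 0;
        for (int i = 0; i < labels.length; i++) {
            if (mask[i]) {
                total -= logSoftmax(logits[i])[labels[i]];
                count++;
            }
        }
        return count == 0 ? 0.0 : total / count;
    }

    /** Assumed composite: the two component losses are simply added. */
    static double pretrainingLoss(double[][] nspLogits, int[] nspLabels,
                                  double[][] mlmLogits, int[] mlmLabels, boolean[] mlmMask) {
        return nextSentenceLoss(nspLogits, nspLabels)
                + maskedLanguageModelLoss(mlmLogits, mlmLabels, mlmMask);
    }

    public static void main(String[] args) {
        double[][] nspLogits = {{0.0, 0.0}};            // uniform over 2 classes: ln 2
        int[] nspLabels = {0};
        double[][] mlmLogits = {{0.0, 0.0, 0.0, 0.0}, {5.0, 1.0, 0.0, 0.0}};
        int[] mlmLabels = {2, 0};
        boolean[] mlmMask = {true, false};               // only position 0 counts: ln 4
        System.out.println(pretrainingLoss(nspLogits, nspLabels, mlmLogits, mlmLabels, mlmMask));
    }
}
```

With uniform logits the next sentence term is ln 2 and the single masked position contributes ln 4, so main prints approximately ln 8. The unmasked second position does not affect the result.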
Modifier and Type | Method | Description
---|---|---
BertMaskedLanguageModelLoss | getBertMaskedLanguageModelLoss() | Gets the BertMaskedLanguageModelLoss component.
BertNextSentenceLoss | getBertNextSentenceLoss() | Gets the BertNextSentenceLoss component.
protected ai.djl.util.Pair<NDList,NDList> | inputForComponent(int componentIndex, NDList labels, NDList predictions) | Returns the inputs for computing the loss of a single component loss.
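The routing that inputForComponent performs can be illustrated with a simplified composite-loss pattern: the composite asks the subclass which (labels, predictions) pair each component should receive, evaluates every component on its own inputs, and sums the results. This is a hypothetical sketch. SimpleLoss, CompositeLossSketch, PerHeadCompositeLoss, L1Sketch, and CompositeLossDemo are illustrative names; java.util.Map.Entry stands in for ai.djl.util.Pair and List<double[]> stands in for NDList. Routing component i to the i-th label and prediction is one possible choice, not necessarily how BertPretrainingLoss selects its inputs.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.List;
import java.util.Map;

/** Hypothetical stand-in for a single loss: maps (labels, predictions) to a scalar. */
interface SimpleLoss {
    double evaluate(List<double[]> labels, List<double[]> predictions);
}

/** Hypothetical analogue of a composite loss; assumes components are summed. */
abstract class CompositeLossSketch implements SimpleLoss {
    protected final List<SimpleLoss> components;

    CompositeLossSketch(List<SimpleLoss> components) {
        this.components = components;
    }

    /** Analogue of inputForComponent: picks the (labels, predictions) pair for one component. */
    protected abstract Map.Entry<List<double[]>, List<double[]>> inputForComponent(
            int componentIndex, List<double[]> labels, List<double[]> predictions);

    @Override
    public double evaluate(List<double[]> labels, List<double[]> predictions) {
        double total = 0.0;
        for (int i = 0; i < components.size(); i++) {
            Map.Entry<List<double[]>, List<double[]>> in = inputForComponent(i, labels, predictions);
            total += components.get(i).evaluate(in.getKey(), in.getValue());
        }
        return total;
    }
}

/** Example routing: component i sees only labels.get(i) and predictions.get(i). */
final class PerHeadCompositeLoss extends CompositeLossSketch {
    PerHeadCompositeLoss(List<SimpleLoss> components) {
        super(components);
    }

    @Override
    protected Map.Entry<List<double[]>, List<double[]>> inputForComponent(
            int componentIndex, List<double[]> labels, List<double[]> predictions) {
        return new SimpleEntry<>(List.of(labels.get(componentIndex)),
                List.of(predictions.get(componentIndex)));
    }
}

/** Simple component: mean absolute error over the first array in each list. */
final class L1Sketch implements SimpleLoss {
    @Override
    public double evaluate(List<double[]> labels, List<double[]> predictions) {
        double[] y = labels.get(0);
        double[] p = predictions.get(0);
        double sum = 0.0;
        for (int i = 0; i < y.length; i++) {
            sum += Math.abs(y[i] - p[i]);
        }
        return sum / y.length;
    }
}

final class CompositeLossDemo {
    public static void main(String[] args) {
        SimpleLoss loss = new PerHeadCompositeLoss(List.of(new L1Sketch(), new L1Sketch()));
        List<double[]> labels = List.of(new double[] {1, 2}, new double[] {0});
        List<double[]> predictions = List.of(new double[] {1, 4}, new double[] {3});
        // component 0: (|1-1| + |2-4|) / 2 = 1.0; component 1: |0-3| / 1 = 3.0; sum = 4.0
        System.out.println(loss.evaluate(labels, predictions));
    }
}
```

Keeping the routing in one overridable method lets the summation live in the base class, while each subclass only describes how its labels and model outputs map onto the components.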
Methods inherited from class AbstractCompositeLoss: addAccumulator, evaluate, getAccumulator, getComponents, resetAccumulator, updateAccumulator
Methods inherited from class Loss: hingeLoss, l1Loss, l2Loss, maskedSoftmaxCrossEntropyLoss, sigmoidBinaryCrossEntropyLoss, softmaxCrossEntropyLoss (three overloads each)
Methods inherited from class Evaluator: checkLabelShapes (two overloads), getName
public BertPretrainingLoss()
protected ai.djl.util.Pair<NDList,NDList> inputForComponent(int componentIndex, NDList labels, NDList predictions)
Description copied from class: AbstractCompositeLoss
Returns the inputs for computing the loss of a single component loss.
Specified by: inputForComponent in class AbstractCompositeLoss
Parameters:
componentIndex - the index of the component loss
labels - the label input to the composite loss
predictions - the predictions input to the composite loss

public BertNextSentenceLoss getBertNextSentenceLoss()
Gets the BertNextSentenceLoss component.
public BertMaskedLanguageModelLoss getBertMaskedLanguageModelLoss()
Gets the BertMaskedLanguageModelLoss component.