Class BertNextSentenceBlock

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.transformer.BertNextSentenceBlock
All Implemented Interfaces:
Block

public class BertNextSentenceBlock extends AbstractBlock
Block that performs the BERT next-sentence-prediction (NSP) task.
  • Constructor Details

    • BertNextSentenceBlock

      public BertNextSentenceBlock()
      Creates a next sentence block.
  • Method Details

    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. Override this method if your subclass has child blocks. It is used to determine the correct input shapes for the child blocks from the input shapes requested for this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
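The shape propagation described above can be illustrated with a minimal plain-Java sketch. It assumes, without confirmation from this page, that the block's only child is a dense layer with two output units, and that the child's weight width is taken from the last dimension of the requested input shape. `DenseSketch` and `NextSentenceHeadSketch` are hypothetical names, not DJL classes, and shapes are modeled as `long[]` rather than `Shape`.

```java
import java.util.Arrays;

/**
 * Sketch of child-block shape propagation, ASSUMING the next-sentence head wraps a single
 * dense layer with 2 output units. DenseSketch and NextSentenceHeadSketch are hypothetical
 * classes, not DJL APIs.
 */
public class ChildInitSketch {

    /** Hypothetical dense layer whose weight shape depends on the input width. */
    static final class DenseSketch {
        private final int units;
        private double[][] weight; // [units][inputWidth], allocated during initialize()

        DenseSketch(int units) {
            this.units = units;
        }

        /** Allocates parameters once the input shape {batch, width} is known. */
        void initialize(long[] inputShape) {
            int width = (int) inputShape[inputShape.length - 1];
            weight = new double[units][width];
        }

        int[] weightShape() {
            return new int[] {weight.length, weight[0].length};
        }
    }

    /** Hypothetical parent block that forwards its requested input shape to its child. */
    static final class NextSentenceHeadSketch {
        final DenseSketch classifier = new DenseSketch(2);

        void initializeChildBlocks(long[]... inputShapes) {
            // Assumption: the pooled output is the only input, with shape {batch, hidden}.
            classifier.initialize(inputShapes[0]);
        }
    }

    public static void main(String[] args) {
        NextSentenceHeadSketch head = new NextSentenceHeadSketch();
        head.initializeChildBlocks(new long[] {8, 768});
        System.out.println(Arrays.toString(head.classifier.weightShape())); // [2, 768]
    }
}
```

The parent never hard-codes the hidden size; it only forwards the shape it was asked to accept, which is the reason such a method exists.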
    • forwardInternal

      protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Performs the forward pass of this block.
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      ps - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
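This page does not document what the forward pass computes. As a hedged illustration only, the sketch below assumes the common BERT design: a dense layer with two output units applied to the pooled `[CLS]` vector, followed by a log-softmax over the class axis. `NspForwardSketch` and its methods are hypothetical and use plain arrays instead of `NDList`.

```java
import java.util.Arrays;

/**
 * Plain-Java sketch of a next-sentence-prediction head, ASSUMING it applies a dense layer
 * with 2 output units to the pooled [CLS] vector and returns log-probabilities via
 * log-softmax. NspForwardSketch is a hypothetical class, not DJL code.
 */
public class NspForwardSketch {

    /** Computes logits = x * W^T + b for each row of a {batch, hidden} input. */
    static double[][] dense(double[][] x, double[][] w, double[] b) {
        double[][] out = new double[x.length][w.length];
        for (int i = 0; i < x.length; i++) {
            for (int k = 0; k < w.length; k++) {
                double sum = b[k];
                for (int j = 0; j < x[i].length; j++) {
                    sum += x[i][j] * w[k][j];
                }
                out[i][k] = sum;
            }
        }
        return out;
    }

    /** Numerically stable log-softmax over the class axis (axis 1). */
    static double[][] logSoftmax(double[][] logits) {
        double[][] out = new double[logits.length][];
        for (int i = 0; i < logits.length; i++) {
            double max = Arrays.stream(logits[i]).max().orElse(0.0);
            double sumExp = 0.0;
            for (double v : logits[i]) {
                sumExp += Math.exp(v - max);
            }
            double logZ = max + Math.log(sumExp);
            out[i] = new double[logits[i].length];
            for (int k = 0; k < logits[i].length; k++) {
                out[i][k] = logits[i][k] - logZ;
            }
        }
        return out;
    }

    /** Hypothetical forward pass: pooled {batch, hidden} -> log-probabilities {batch, 2}. */
    static double[][] forward(double[][] pooled, double[][] w, double[] b) {
        return logSoftmax(dense(pooled, w, b));
    }

    public static void main(String[] args) {
        double[][] pooled = {{1.0, 0.0}, {0.0, 1.0}}; // batch of 2, hidden size 2
        double[][] w = {{2.0, 0.0}, {0.0, 2.0}};      // one weight row per output class
        double[] b = {0.0, 0.0};
        for (double[] row : forward(pooled, w, b)) {
            // exp() of each row's entries sums to 1
            System.out.println(Arrays.toString(row));
        }
    }
}
```

Taking the index of the larger log-probability in each row gives the predicted class; which index means "is next sentence" is not stated on this page.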
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
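A minimal sketch of the kind of shape rule this method could implement, assuming (not confirmed here) that the block takes one pooled input of shape {batch, hidden} and emits one score per class for the two NSP classes. Shapes are modeled as plain `long[]` arrays instead of `ai.djl.ndarray.types.Shape`, and `NspShapeSketch` is a hypothetical name.

```java
import java.util.Arrays;

/**
 * Sketch of the output-shape rule ASSUMED for a next-sentence head: a pooled input of shape
 * {batch, hidden} maps to {batch, 2}. NspShapeSketch is hypothetical, not DJL code.
 */
public class NspShapeSketch {

    static final int NUM_CLASSES = 2; // assumed: "is next" vs. "not next"

    static long[][] getOutputShapes(long[][] inputShapes) {
        if (inputShapes.length == 0 || inputShapes[0].length != 2) {
            throw new IllegalArgumentException("expected one input of shape {batch, hidden}");
        }
        long batch = inputShapes[0][0];
        // The hidden dimension is consumed by the classifier; only the batch size is kept.
        return new long[][] {{batch, NUM_CLASSES}};
    }

    public static void main(String[] args) {
        long[][] out = getOutputShapes(new long[][] {{32, 768}});
        System.out.println(Arrays.toString(out[0])); // [32, 2]
    }
}
```

Because the rule depends only on the input shapes, it can be evaluated before any parameters are allocated.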