Class BertPretrainingBlock

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractBlock
ai.djl.nn.transformer.BertPretrainingBlock
All Implemented Interfaces:
Block

public class BertPretrainingBlock extends AbstractBlock
A block that performs both BERT pretraining tasks: next sentence prediction and masked language modeling.
  • Constructor Details

    • BertPretrainingBlock

      public BertPretrainingBlock(BertBlock.Builder builder)
      Creates a new BERT pretraining block that matches the given BERT configuration.
      Parameters:
      builder - a builder holding the BERT configuration
  • Method Details

    • initializeChildBlocks

      public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes)
      Initializes the child blocks of this block. Override this method if your subclass has child blocks. It determines the correct input shapes for the child blocks from the input shapes requested for this block.
      Overrides:
      initializeChildBlocks in class AbstractBaseBlock
      Parameters:
      manager - the manager to use for initialization
      dataType - the requested data type
      inputShapes - the expected input shapes for this block
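The idea of deriving child input shapes from the parent's input shapes can be sketched in plain Java. This is a minimal illustration and not DJL code: `ChildShapeSketch`, the child names, and the use of `long[]` in place of `Shape` are hypothetical. It also assumes, which this page does not state, that the block receives four inputs in the order token ids (B, S), type ids (B, S), sequence masks (B, S), masked indices (B, I), and that the encoder child produces an embedded sequence (B, S, E) and a pooled output (B, E). Verify these against the DJL source before relying on them.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical stand-in, not part of DJL: shapes are modeled as plain long[] arrays. */
final class ChildShapeSketch {

    /**
     * Derives the input shapes each (assumed) child block would be initialized with.
     * Assumed parent input order: tokenIds (B, S), typeIds (B, S),
     * sequenceMasks (B, S), maskedIndices (B, I).
     */
    static Map<String, long[][]> childInputShapes(
            long[][] inputShapes, long embeddingSize, long tokenDictionarySize) {
        if (inputShapes.length != 4) {
            throw new IllegalArgumentException("expected 4 input shapes, got " + inputShapes.length);
        }
        long batchSize = inputShapes[0][0];
        long sequenceLength = inputShapes[0][1];

        // Assumption: the encoder child consumes the first three parent inputs unchanged.
        long[][] encoderInputs = {inputShapes[0], inputShapes[1], inputShapes[2]};

        // Assumption: the encoder yields one embedding per token plus one pooled vector per sequence.
        long[] embeddedSequence = {batchSize, sequenceLength, embeddingSize};
        long[] pooledOutput = {batchSize, embeddingSize};
        long[] embeddingTable = {tokenDictionarySize, embeddingSize};

        Map<String, long[][]> shapes = new LinkedHashMap<>();
        shapes.put("encoder", encoderInputs);
        // The masked language head needs the token embeddings, the embedding table,
        // and the positions of the masked tokens.
        shapes.put("maskedLanguageModel",
                new long[][] {embeddedSequence, embeddingTable, inputShapes[3]});
        // The next sentence head only needs the pooled sequence representation.
        shapes.put("nextSentence", new long[][] {pooledOutput});
        return shapes;
    }
}
```

In a real subclass the same bookkeeping would happen inside `initializeChildBlocks`, with each child's own initialization call receiving the shapes computed for it.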
    • forwardInternal

      protected NDList forwardInternal(ParameterStore ps, NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Runs the forward pass of this block on the given inputs.
      Specified by:
      forwardInternal in class AbstractBaseBlock
      Parameters:
      ps - the parameter store
      inputs - the input NDList
      training - true for a training forward pass
      params - optional parameters
      Returns:
      the output of the forward pass
    • getOutputShapes

      public Shape[] getOutputShapes(Shape[] inputShapes)
      Returns the expected output shapes of the block for the specified input shapes.
      Parameters:
      inputShapes - the shapes of the inputs
      Returns:
      the expected output shapes of the block
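How the output shapes might follow from the input shapes can be shown with a small plain-Java sketch. It is illustrative only: `PretrainingOutputShapes` is a hypothetical helper, `long[]` stands in for `Shape`, and the input layout (token ids, type ids, sequence masks, masked indices) and output layout (next sentence scores of shape (B, 2), masked token scores of shape (B, I, D)) are assumptions not stated on this page.

```java
import java.util.Arrays;

/** Hypothetical stand-in, not part of DJL: shapes are modeled as plain long[] arrays. */
final class PretrainingOutputShapes {

    /**
     * Computes the assumed output shapes for the two pretraining heads.
     * Assumed input order: tokenIds (B, S), typeIds (B, S),
     * sequenceMasks (B, S), maskedIndices (B, I).
     */
    static long[][] outputShapes(long[][] inputShapes, long tokenDictionarySize) {
        if (inputShapes.length != 4) {
            throw new IllegalArgumentException("expected 4 input shapes, got " + inputShapes.length);
        }
        long batchSize = inputShapes[0][0];
        long maskedIndexCount = inputShapes[3][1];
        return new long[][] {
            // Next sentence task: two classes (is next, is not next) per sequence.
            {batchSize, 2},
            // Masked language task: one score per dictionary token for each masked position.
            {batchSize, maskedIndexCount, tokenDictionarySize}
        };
    }

    public static void main(String[] args) {
        long[][] inputs = {{8, 128}, {8, 128}, {8, 128}, {8, 20}};
        // prints [[8, 2], [8, 20, 30522]]
        System.out.println(Arrays.deepToString(outputShapes(inputs, 30522L)));
    }
}
```

Under these assumptions the sequence length does not appear in either output shape, because the next sentence head works on a pooled representation and the masked language head only scores the masked positions.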