Class PtSymbolBlock

java.lang.Object
ai.djl.nn.AbstractBaseBlock
ai.djl.nn.AbstractSymbolBlock
ai.djl.pytorch.engine.PtSymbolBlock
All Implemented Interfaces:
ai.djl.nn.Block, ai.djl.nn.SymbolBlock, AutoCloseable

public class PtSymbolBlock extends ai.djl.nn.AbstractSymbolBlock implements AutoCloseable
PtSymbolBlock is the PyTorch implementation of SymbolBlock.

You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).

  • Field Summary

    Fields inherited from class ai.djl.nn.AbstractBaseBlock

    inputNames, inputShapes, outputDataTypes, version
  • Constructor Summary

    Constructors
    Constructor                                        Description
    PtSymbolBlock(PtNDManager manager)                 Constructs an empty PtSymbolBlock.
    PtSymbolBlock(PtNDManager manager, long handle)    Constructs a PtSymbolBlock.
  • Method Summary

    Modifier and Type    Method    Description
    void  close()
    ai.djl.util.PairList<String,ai.djl.ndarray.types.Shape>  describeInput()
    ai.djl.util.PairList<String,ai.djl.ndarray.types.Shape>  describeOutput()
    IValue  forward(IValue... inputs)
        Runs the forward pass of this PyTorch module.
    protected ai.djl.ndarray.NDList  forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
    ai.djl.nn.ParameterList  getDirectParameters()
    Long  getHandle()
        Gets the native PyTorch model pointer.
    ai.djl.ndarray.types.Shape[]  getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
    ai.djl.ndarray.types.Shape[]  getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
    void  loadParameters(ai.djl.ndarray.NDManager manager, DataInputStream is)
    void  saveParameters(DataOutputStream os)

    Methods inherited from class ai.djl.nn.AbstractSymbolBlock

    getChildren

    Methods inherited from class ai.djl.nn.AbstractBaseBlock

    beforeInitialize, cast, clear, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

    Methods inherited from interface ai.djl.nn.Block

    cast, clear, forward, forward, forward, freezeParameters, freezeParameters, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, setInitializer, setInitializer, setInitializer

    Methods inherited from interface ai.djl.nn.SymbolBlock

    removeLastBlock
  • Constructor Details

    • PtSymbolBlock

      public PtSymbolBlock(PtNDManager manager, long handle)
      Constructs a PtSymbolBlock.

      You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).

      Parameters:
      manager - the manager to use for the block
      handle - the module handle
    • PtSymbolBlock

      public PtSymbolBlock(PtNDManager manager)
      Constructs an empty PtSymbolBlock.
      Parameters:
      manager - the manager to use for the block
  • Method Details

    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
    • forward

      public IValue forward(IValue... inputs)
      Runs the forward pass of this PyTorch module.
      Parameters:
      inputs - the input IValues
      Returns:
      the result IValue
    • forwardInternal

      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<String,Object> params)
      Specified by:
      forwardInternal in class ai.djl.nn.AbstractBaseBlock
    • describeInput

      public ai.djl.util.PairList<String,ai.djl.ndarray.types.Shape> describeInput()
      Specified by:
      describeInput in interface ai.djl.nn.Block
      Overrides:
      describeInput in class ai.djl.nn.AbstractBaseBlock
    • getDirectParameters

      public ai.djl.nn.ParameterList getDirectParameters()
      Specified by:
      getDirectParameters in interface ai.djl.nn.Block
    • describeOutput

      public ai.djl.util.PairList<String,ai.djl.ndarray.types.Shape> describeOutput()
      Specified by:
      describeOutput in interface ai.djl.nn.SymbolBlock
    • getOutputShapes

      public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      Specified by:
      getOutputShapes in interface ai.djl.nn.Block
      Overrides:
      getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
    • getOutputShapes

      public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
      Specified by:
      getOutputShapes in interface ai.djl.nn.Block
    • saveParameters

      public void saveParameters(DataOutputStream os) throws IOException
      Specified by:
      saveParameters in interface ai.djl.nn.Block
      Overrides:
      saveParameters in class ai.djl.nn.AbstractBaseBlock
      Throws:
      IOException
    • loadParameters

      public void loadParameters(ai.djl.ndarray.NDManager manager, DataInputStream is) throws IOException, ai.djl.MalformedModelException
      Specified by:
      loadParameters in interface ai.djl.nn.Block
      Overrides:
      loadParameters in class ai.djl.nn.AbstractBaseBlock
      Throws:
      IOException
      ai.djl.MalformedModelException
    • getHandle

      public Long getHandle()
      Gets the native PyTorch model pointer.
      Returns:
      the pointer
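saveParameters(DataOutputStream) and loadParameters(NDManager, DataInputStream) both work on streams that the caller supplies, so the caller decides where the bytes go. The following is a minimal sketch that assumes an in-memory buffer is the destination. ParameterBytes, ParameterWriter and ParameterReader are hypothetical helpers and are not part of DJL. ParameterWriter has the same shape as saveParameters, so a method reference such as block::saveParameters should fit it. Loading can be written as is -> block.loadParameters(manager, is). The sketch does not assume anything about the byte format the block writes, and it uses stand-in payloads so that it compiles with the JDK alone.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

/** Hypothetical helpers (not DJL API) that route parameter streams through a byte array. */
final class ParameterBytes {

    /** Same shape as saveParameters(DataOutputStream), so block::saveParameters fits. */
    @FunctionalInterface
    interface ParameterWriter {
        void write(DataOutputStream os) throws IOException;
    }

    /** Broad enough for is -> block.loadParameters(manager, is), which declares two checked exceptions. */
    @FunctionalInterface
    interface ParameterReader {
        void read(DataInputStream is) throws Exception;
    }

    private ParameterBytes() {}

    /** Runs the writer against a fresh buffer and returns everything it wrote. */
    static byte[] save(ParameterWriter writer) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (DataOutputStream os = new DataOutputStream(buffer)) {
            writer.write(os);
        }
        // The stream is closed (and therefore flushed) before the bytes are read out.
        return buffer.toByteArray();
    }

    /** Replays previously saved bytes into the reader. */
    static void load(byte[] bytes, ParameterReader reader) throws Exception {
        try (DataInputStream is = new DataInputStream(new ByteArrayInputStream(bytes))) {
            reader.read(is);
        }
    }

    public static void main(String[] args) throws Exception {
        // Stand-in payload; with DJL the writer would be block::saveParameters.
        byte[] bytes = save(os -> {
            os.writeInt(3);
            os.writeUTF("weights");
        });
        // Stand-in reader; with DJL this would be is -> block.loadParameters(manager, is).
        load(bytes, is -> System.out.println(is.readInt() + " " + is.readUTF()));
    }
}
```

To write to a file instead, replace the ByteArrayOutputStream with a FileOutputStream (and the ByteArrayInputStream with a FileInputStream). The stream handling stays the same.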
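getHandle() returns a raw native pointer as a Long. Because the class implements AutoCloseable, that pointer is only meaningful until close() runs. The sketch below shows one common way to pair such an accessor with close(). HandleOwnerSketch is a hypothetical stand-in, not DJL's source. It assumes the owner releases its pointer at most once and fails fast if the pointer is requested afterwards. The native free call is simulated with a print statement.

```java
import java.util.concurrent.atomic.AtomicReference;

/**
 * Hypothetical stand-in (not DJL source) for an object that owns a native pointer
 * through a getHandle()/close() pair. The "pointer" is just a number; nothing is freed.
 */
class HandleOwnerSketch implements AutoCloseable {

    // The native pointer, or null once it has been released.
    private final AtomicReference<Long> handle;

    HandleOwnerSketch(long pointer) {
        this.handle = new AtomicReference<>(pointer);
    }

    /** Returns the pointer, refusing to hand out one that was already released. */
    public Long getHandle() {
        Long pointer = handle.get();
        if (pointer == null) {
            throw new IllegalStateException("native handle has been released");
        }
        return pointer;
    }

    /** Releases the pointer at most once; getAndSet makes repeated or concurrent calls safe. */
    @Override
    public void close() {
        Long pointer = handle.getAndSet(null);
        if (pointer != null) {
            // A real wrapper would call into native code here to free the module.
            System.out.println("released native pointer " + pointer);
        }
    }

    public static void main(String[] args) {
        HandleOwnerSketch owner = new HandleOwnerSketch(42L);
        System.out.println("using pointer " + owner.getHandle());
        owner.close();
        owner.close(); // second call is a no-op
        try {
            owner.getHandle();
        } catch (IllegalStateException e) {
            System.out.println("after close: " + e.getMessage());
        }
    }
}
```

In application code, the usual way to enforce this discipline is to open the owning object in a try-with-resources statement. That way close() runs even if the code inside the block throws.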