Class PtSymbolBlock

  • All Implemented Interfaces:
    ai.djl.nn.Block, ai.djl.nn.SymbolBlock, java.lang.AutoCloseable

    public class PtSymbolBlock
    extends ai.djl.nn.AbstractSymbolBlock
    implements java.lang.AutoCloseable
    PtSymbolBlock is the PyTorch implementation of SymbolBlock.

    You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).

    • Field Summary

      • Fields inherited from class ai.djl.nn.AbstractBaseBlock

        inputNames, inputShapes, outputDataTypes, version
    • Method Summary

      Modifier and Type Method Description
      void close()
      ai.djl.util.PairList<java.lang.String, ai.djl.ndarray.types.Shape> describeInput()
      ai.djl.util.PairList<java.lang.String, ai.djl.ndarray.types.Shape> describeOutput()
      IValue forward(IValue... inputs)
      Runs the forward pass of this PyTorch module.
      protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore, ai.djl.ndarray.NDList inputs, boolean training, ai.djl.util.PairList<java.lang.String, java.lang.Object> params)
      ai.djl.nn.ParameterList getDirectParameters()
      java.lang.Long getHandle()
      Gets the native PyTorch model pointer.
      ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
      ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes, ai.djl.ndarray.types.DataType[] dataTypes)
      void loadParameters(ai.djl.ndarray.NDManager manager, java.io.DataInputStream is)
      void saveParameters(java.io.DataOutputStream os)
      • Methods inherited from class ai.djl.nn.AbstractSymbolBlock

        getChildren
      • Methods inherited from class ai.djl.nn.AbstractBaseBlock

        beforeInitialize, cast, clear, forward, forward, forwardInternal, getInputShapes, getOutputDataTypes, getParameters, initialize, initializeChildBlocks, isInitialized, loadMetadata, prepare, readInputShapes, saveInputShapes, saveMetadata, setInitializer, setInitializer, setInitializer, toString
      • Methods inherited from class java.lang.Object

        clone, equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
      • Methods inherited from interface ai.djl.nn.Block

        cast, clear, forward, forward, forward, freezeParameters, getInputShapes, getOutputDataTypes, getParameters, initialize, isInitialized, setInitializer, setInitializer, setInitializer
      • Methods inherited from interface ai.djl.nn.SymbolBlock

        removeLastBlock
    • Constructor Detail

      • PtSymbolBlock

        public PtSymbolBlock(PtNDManager manager,
                             long handle)
        Constructs a PtSymbolBlock.

        You can create a PtSymbolBlock using Model.load(java.nio.file.Path, String).

        Parameters:
        manager - the manager to use for the block
        handle - the module handle
      • PtSymbolBlock

        public PtSymbolBlock(PtNDManager manager)
        Constructs an empty PtSymbolBlock.
        Parameters:
        manager - the manager to use for the block
    • Method Detail

      • close

        public void close()
        Specified by:
        close in interface java.lang.AutoCloseable
      • forward

        public IValue forward(IValue... inputs)
        Runs the forward pass of this PyTorch module.
        Parameters:
        inputs - the input IValues
        Returns:
        the result IValue
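        Because forward is declared with a varargs parameter (IValue... inputs), a caller can pass several inputs directly, pass an existing array, or pass nothing at all; inside the method the inputs always arrive as an array. The JDK-only sketch below illustrates those calling forms. Value is a hypothetical stand-in for IValue, and the joining logic is invented for illustration; no PyTorch module runs here.

```java
/** Hypothetical sketch: Value stands in for IValue; no PyTorch code runs here. */
public class VarargsForwardExample {

    static final class Value {
        final String label;

        Value(String label) {
            this.label = label;
        }
    }

    /** Mirrors the shape of forward(IValue... inputs): joins input labels with '+'. */
    static Value forward(Value... inputs) {
        // The varargs parameter is an ordinary Value[] inside the method.
        StringBuilder joined = new StringBuilder();
        for (Value v : inputs) {
            if (joined.length() > 0) {
                joined.append('+');
            }
            joined.append(v.label);
        }
        return new Value(joined.toString());
    }

    public static void main(String[] args) {
        Value a = new Value("a");
        Value b = new Value("b");
        System.out.println(forward(a, b).label);                  // a+b
        System.out.println(forward(new Value[] {a, b, a}).label); // a+b+a
        System.out.println(forward().label.isEmpty());            // true
    }
}
```

        Passing an array directly is convenient when the number of model inputs is only known at run time.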
      • forwardInternal

        protected ai.djl.ndarray.NDList forwardInternal(ai.djl.training.ParameterStore parameterStore,
                                                        ai.djl.ndarray.NDList inputs,
                                                        boolean training,
                                                        ai.djl.util.PairList<java.lang.String, java.lang.Object> params)
        Specified by:
        forwardInternal in class ai.djl.nn.AbstractBaseBlock
      • describeInput

        public ai.djl.util.PairList<java.lang.String, ai.djl.ndarray.types.Shape> describeInput()
        Specified by:
        describeInput in interface ai.djl.nn.Block
        Overrides:
        describeInput in class ai.djl.nn.AbstractBaseBlock
      • getDirectParameters

        public ai.djl.nn.ParameterList getDirectParameters()
        Specified by:
        getDirectParameters in interface ai.djl.nn.Block
      • describeOutput

        public ai.djl.util.PairList<java.lang.String, ai.djl.ndarray.types.Shape> describeOutput()
        Specified by:
        describeOutput in interface ai.djl.nn.SymbolBlock
      • getOutputShapes

        public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes)
        Specified by:
        getOutputShapes in interface ai.djl.nn.Block
        Overrides:
        getOutputShapes in class ai.djl.nn.AbstractSymbolBlock
      • getOutputShapes

        public ai.djl.ndarray.types.Shape[] getOutputShapes(ai.djl.ndarray.types.Shape[] inputShapes,
                                                            ai.djl.ndarray.types.DataType[] dataTypes)
        Specified by:
        getOutputShapes in interface ai.djl.nn.Block
      • saveParameters

        public void saveParameters(java.io.DataOutputStream os)
                            throws java.io.IOException
        Specified by:
        saveParameters in interface ai.djl.nn.Block
        Overrides:
        saveParameters in class ai.djl.nn.AbstractBaseBlock
        Throws:
        java.io.IOException
      • loadParameters

        public void loadParameters(ai.djl.ndarray.NDManager manager,
                                   java.io.DataInputStream is)
                            throws java.io.IOException,
                                   ai.djl.MalformedModelException
        Specified by:
        loadParameters in interface ai.djl.nn.Block
        Overrides:
        loadParameters in class ai.djl.nn.AbstractBaseBlock
        Throws:
        java.io.IOException
        ai.djl.MalformedModelException
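        saveParameters and loadParameters exchange the block's state through java.io.DataOutputStream and java.io.DataInputStream, so whatever one side writes the other must read back in the same order. This page does not document PtSymbolBlock's actual encoding. The sketch below assumes a hypothetical layout (a version byte followed by a length-prefixed byte blob) purely to show a matching write/read pair with a version check. It throws IOException where the documented loadParameters also declares ai.djl.MalformedModelException, so that it stays JDK-only.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

/** Hypothetical layout, not PtSymbolBlock's real format: [version][length][bytes]. */
public class ParameterStreamExample {

    static final byte VERSION = 1;

    /** Write side, analogous in shape to saveParameters(DataOutputStream). */
    static void save(DataOutputStream os, byte[] moduleBytes) throws IOException {
        os.writeByte(VERSION);
        os.writeInt(moduleBytes.length);
        os.write(moduleBytes);
    }

    /** Read side, analogous in shape to loadParameters(NDManager, DataInputStream). */
    static byte[] load(DataInputStream is) throws IOException {
        byte version = is.readByte();
        if (version != VERSION) {
            // The documented loadParameters also declares ai.djl.MalformedModelException;
            // this sketch uses IOException to stay JDK-only.
            throw new IOException("Unsupported encoding version: " + version);
        }
        byte[] bytes = new byte[is.readInt()];
        is.readFully(bytes); // throws EOFException if the stream is truncated
        return bytes;
    }

    /** Saves to an in-memory buffer and loads the result back. */
    static byte[] roundTrip(byte[] moduleBytes) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            try (DataOutputStream os = new DataOutputStream(buffer)) {
                save(os, moduleBytes);
            }
            try (DataInputStream is =
                    new DataInputStream(new ByteArrayInputStream(buffer.toByteArray()))) {
                return load(is);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        byte[] original = {10, 20, 30};
        System.out.println(Arrays.equals(original, roundTrip(original))); // true
    }
}
```

        Writing the version first lets the reader reject an incompatible stream before it tries to interpret the payload.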
      • getHandle

        public java.lang.Long getHandle()
        Gets the native PyTorch model pointer.
        Returns:
        the pointer
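getHandle() returns a boxed java.lang.Long rather than a primitive long, and the class implements java.lang.AutoCloseable. One common way to combine those two traits is sketched below: a nullable handle that close() clears exactly once, used with try-with-resources. This is a hypothetical JDK-only stand-in, not PtSymbolBlock's source. The HandleOwner class and its deleter callback (representing whatever native call frees the module) are assumptions for illustration.

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongConsumer;

/**
 * Hypothetical sketch, not DJL source: an AutoCloseable owner of a native
 * pointer whose getHandle() returns null once the owner is closed.
 */
public class NativeHandleExample {

    static final class HandleOwner implements AutoCloseable {
        // Boxed Long so that "no handle" can be represented as null.
        private final AtomicReference<Long> handle;
        // Stand-in for the native call that frees the module (assumption).
        private final LongConsumer deleter;

        HandleOwner(long pointer, LongConsumer deleter) {
            this.handle = new AtomicReference<>(pointer);
            this.deleter = deleter;
        }

        /** Returns the pointer, or null after close(). */
        Long getHandle() {
            return handle.get();
        }

        @Override
        public void close() {
            // getAndSet(null) lets only the first close() see the pointer,
            // so repeated calls never free the same pointer twice.
            Long pointer = handle.getAndSet(null);
            if (pointer != null) {
                deleter.accept(pointer);
            }
        }
    }

    public static void main(String[] args) {
        int[] frees = {0};
        HandleOwner owner = new HandleOwner(0x1234L, p -> frees[0]++);
        try (HandleOwner o = owner) {
            System.out.println("open handle: " + o.getHandle()); // open handle: 4660
        }
        owner.close(); // second close is a no-op
        System.out.println("after close: " + owner.getHandle() + ", frees = " + frees[0]);
        // after close: null, frees = 1
    }
}
```

AtomicReference keeps the release idempotent even if two threads race to close the owner; for strictly single-threaded use, a plain field with a null check would be enough.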