Class JniUtils


  • public final class JniUtils
    extends java.lang.Object
    A class containing utilities to interact with the PyTorch Engine's Java Native Interface (JNI) layer.
    • Method Detail

      • isGradMode

        public static boolean isGradMode()
      • setGradMode

        public static void setGradMode(boolean enable)
      • getNumInteropThreads

        public static int getNumInteropThreads()
      • getNumThreads

        public static int getNumThreads()
      • setNumInteropThreads

        public static void setNumInteropThreads(int threads)
      • setNumThreads

        public static void setNumThreads(int threads)
      • setBenchmarkCuDNN

        public static void setBenchmarkCuDNN(boolean enable)
      • getFeatures

        public static java.util.Set<java.lang.String> getFeatures()
      • setSeed

        public static void setSeed(long seed)
      • startProfile

        public static void startProfile(boolean useCuda,
                                        boolean recordShape,
                                        boolean profileMemory)
        Call this method to start profiling the region of code you are interested in.

        Example usage

              JniUtils.startProfile(false, true, true);
              predictor.predict(img);
              JniUtils.stopProfile(outputFile);
         
        Parameters:
        useCuda - Enables timing of CUDA events as well, using the cudaEvent API.
        recordShape - If shape recording is set, information about input dimensions will be collected.
        profileMemory - Whether to report memory usage.
      • stopProfile

        public static void stopProfile(java.lang.String outputFile)
      • createNdFromByteBuffer

        public static PtNDArray createNdFromByteBuffer(PtNDManager manager,
                                                       java.nio.ByteBuffer data,
                                                       ai.djl.ndarray.types.Shape shape,
                                                       ai.djl.ndarray.types.DataType dType,
                                                       ai.djl.ndarray.types.SparseFormat fmt,
                                                       ai.djl.Device device)
      • emptyCudaCache

        public static void emptyCudaCache()
      • createEmptyNdArray

        public static PtNDArray createEmptyNdArray(PtNDManager manager,
                                                   ai.djl.ndarray.types.Shape shape,
                                                   ai.djl.ndarray.types.DataType dType,
                                                   ai.djl.Device device,
                                                   ai.djl.ndarray.types.SparseFormat fmt)
      • createZerosNdArray

        public static PtNDArray createZerosNdArray(PtNDManager manager,
                                                   ai.djl.ndarray.types.Shape shape,
                                                   ai.djl.ndarray.types.DataType dType,
                                                   ai.djl.Device device,
                                                   ai.djl.ndarray.types.SparseFormat fmt)
      • createOnesNdArray

        public static PtNDArray createOnesNdArray(PtNDManager manager,
                                                  ai.djl.ndarray.types.Shape shape,
                                                  ai.djl.ndarray.types.DataType dType,
                                                  ai.djl.Device device,
                                                  ai.djl.ndarray.types.SparseFormat fmt)
      • full

        public static PtNDArray full(PtNDManager manager,
                                     ai.djl.ndarray.types.Shape shape,
                                     double fillValue,
                                     ai.djl.ndarray.types.DataType dType,
                                     ai.djl.Device device,
                                     ai.djl.ndarray.types.SparseFormat fmt)
      • zerosLike

        public static PtNDArray zerosLike(PtNDArray array,
                                          ai.djl.ndarray.types.DataType dType,
                                          ai.djl.Device device,
                                          ai.djl.ndarray.types.SparseFormat fmt)
      • onesLike

        public static PtNDArray onesLike(PtNDArray array,
                                         ai.djl.ndarray.types.DataType dType,
                                         ai.djl.Device device,
                                         ai.djl.ndarray.types.SparseFormat fmt)
      • arange

        public static PtNDArray arange(PtNDManager manager,
                                       float start,
                                       float stop,
                                       float step,
                                       ai.djl.ndarray.types.DataType dType,
                                       ai.djl.Device device,
                                       ai.djl.ndarray.types.SparseFormat fmt)
      • linspace

        public static PtNDArray linspace(PtNDManager manager,
                                         float start,
                                         float stop,
                                         int step,
                                         ai.djl.ndarray.types.DataType dType,
                                         ai.djl.Device device,
                                         ai.djl.ndarray.types.SparseFormat fmt)
      • createSparseCoo

        public static PtNDArray createSparseCoo(PtNDArray indices,
                                                PtNDArray values,
                                                ai.djl.ndarray.types.Shape shape)
      • to

        public static PtNDArray to(PtNDArray ndArray,
                                   ai.djl.ndarray.types.DataType dataType,
                                   ai.djl.Device device)
      • broadcast

        public static PtNDArray broadcast(PtNDArray ndArray,
                                          ai.djl.ndarray.types.Shape shape)
      • slice

        public static PtNDArray slice(PtNDArray ndArray,
                                      long dim,
                                      long start,
                                      long stop,
                                      long step)
      • index

        public static PtNDArray index(PtNDArray ndArray,
                                      long[] minIndices,
                                      long[] maxIndices,
                                      long[] stepIndices,
                                      PtNDManager manager)
      • indexAdvPut

        public static void indexAdvPut(PtNDArray ndArray,
                                       ai.djl.ndarray.index.NDIndex index,
                                       PtNDArray data)
      • indexSet

        public static void indexSet(PtNDArray ndArray,
                                    PtNDArray value,
                                    long[] minIndices,
                                    long[] maxIndices,
                                    long[] stepIndices)
      • set

        public static void set(PtNDArray self,
                               java.nio.ByteBuffer data)
      • softmax

        public static PtNDArray softmax(PtNDArray ndArray,
                                        long dim,
                                        ai.djl.ndarray.types.DataType dTpe)
      • logSoftmax

        public static PtNDArray logSoftmax(PtNDArray ndArray,
                                           long dim,
                                           ai.djl.ndarray.types.DataType dTpe)
      • argMax

        public static PtNDArray argMax(PtNDArray ndArray,
                                       long dim,
                                       boolean keepDim)
      • topK

        public static ai.djl.ndarray.NDList topK(PtNDArray ndArray,
                                                 long k,
                                                 long axis,
                                                 boolean largest,
                                                 boolean sorted)
      • argMin

        public static PtNDArray argMin(PtNDArray ndArray,
                                       long dim,
                                       boolean keepDim)
      • argSort

        public static PtNDArray argSort(PtNDArray ndArray,
                                        long dim,
                                        boolean keepDim)
      • sort

        public static PtNDArray sort(PtNDArray ndArray,
                                     long dim,
                                     boolean descending)
      • transpose

        public static PtNDArray transpose(PtNDArray ndArray,
                                          long dim1,
                                          long dim2)
      • contentEqual

        public static boolean contentEqual(PtNDArray ndArray1,
                                           PtNDArray ndArray2)
      • signi

        public static void signi(PtNDArray ndArray)
      • median

        public static ai.djl.ndarray.NDList median(PtNDArray ndArray,
                                                   long dim,
                                                   boolean keepDim)
      • sum

        public static PtNDArray sum(PtNDArray ndArray,
                                    long[] dims,
                                    boolean keepDim)
      • cumProd

        public static PtNDArray cumProd(PtNDArray ndArray,
                                        long dim,
                                        ai.djl.ndarray.types.DataType dataType)
      • oneHot

        public static PtNDArray oneHot(PtNDArray ndArray,
                                       int depth,
                                       ai.djl.ndarray.types.DataType dataType)
      • split

        public static ai.djl.ndarray.NDList split(PtNDArray ndArray,
                                                  long size,
                                                  long axis)
      • split

        public static ai.djl.ndarray.NDList split(PtNDArray ndArray,
                                                  long[] indices,
                                                  long axis)
      • unique

        public static ai.djl.ndarray.NDList unique(PtNDArray ndArray,
                                                   java.lang.Integer dim,
                                                   boolean sorted,
                                                   boolean returnInverse,
                                                   boolean returnCounts)
      • flatten

        public static PtNDArray flatten(PtNDArray ndArray,
                                        long startDim,
                                        long endDim)
      • stft

        public static PtNDArray stft(PtNDArray ndArray,
                                     long nFft,
                                     long hopLength,
                                     PtNDArray window,
                                     boolean center,
                                     boolean normalize,
                                     boolean returnComplex)
      • clip

        public static PtNDArray clip(PtNDArray ndArray,
                                     java.lang.Number min,
                                     java.lang.Number max)
      • negi

        public static void negi(PtNDArray ndArray)
      • randint

        public static PtNDArray randint(PtNDManager manager,
                                        long low,
                                        long high,
                                        ai.djl.ndarray.types.Shape size,
                                        ai.djl.ndarray.types.DataType dataType,
                                        ai.djl.Device device)
      • randperm

        public static PtNDArray randperm(PtNDManager manager,
                                         long n,
                                         ai.djl.ndarray.types.DataType dataType,
                                         ai.djl.Device device)
      • normal

        public static PtNDArray normal(PtNDManager manager,
                                       double mean,
                                       double std,
                                       ai.djl.ndarray.types.Shape size,
                                       ai.djl.ndarray.types.DataType dataType,
                                       ai.djl.Device device)
      • uniform

        public static PtNDArray uniform(PtNDManager manager,
                                        double low,
                                        double high,
                                        ai.djl.ndarray.types.Shape size,
                                        ai.djl.ndarray.types.DataType dataType,
                                        ai.djl.Device device)
      • eye

        public static PtNDArray eye(PtNDManager manager,
                                    int n,
                                    int m,
                                    ai.djl.ndarray.types.DataType dataType,
                                    ai.djl.Device device,
                                    ai.djl.ndarray.types.SparseFormat fmt)
      • hannWindow

        public static PtNDArray hannWindow(PtNDManager manager,
                                           long numPoints,
                                           boolean periodic,
                                           ai.djl.Device device)
      • interpolate

        public static PtNDArray interpolate(PtNDArray ndArray,
                                            long[] size,
                                            int mode,
                                            boolean alignCorners)
      • leakyRelu

        public static PtNDArray leakyRelu(PtNDArray ndArray,
                                          double negativeSlope)
      • convolution

        public static PtNDArray convolution(PtNDArray ndArray,
                                            PtNDArray weight,
                                            PtNDArray bias,
                                            ai.djl.ndarray.types.Shape stride,
                                            ai.djl.ndarray.types.Shape padding,
                                            ai.djl.ndarray.types.Shape dilation,
                                            int groups)
      • normalize

        public static PtNDArray normalize(PtNDArray ndArray,
                                          double p,
                                          long dim,
                                          double eps)
      • dropout

        public static PtNDArray dropout(PtNDArray ndArray,
                                        double prob,
                                        boolean training)
      • rnn

        public static ai.djl.ndarray.NDList rnn(PtNDArray input,
                                                PtNDArray hx,
                                                ai.djl.ndarray.NDList params,
                                                boolean hasBiases,
                                                int numLayers,
                                                ai.djl.nn.recurrent.RNN.Activation activation,
                                                double dropRate,
                                                boolean training,
                                                boolean bidirectional,
                                                boolean batchFirst)
      • gru

        public static ai.djl.ndarray.NDList gru(PtNDArray input,
                                                PtNDArray hx,
                                                ai.djl.ndarray.NDList params,
                                                boolean hasBiases,
                                                int numLayers,
                                                double dropRate,
                                                boolean training,
                                                boolean bidirectional,
                                                boolean batchFirst)
      • lstm

        public static ai.djl.ndarray.NDList lstm(PtNDArray input,
                                                 ai.djl.ndarray.NDList hx,
                                                 ai.djl.ndarray.NDList params,
                                                 boolean hasBiases,
                                                 int numLayers,
                                                 double dropRate,
                                                 boolean training,
                                                 boolean bidirectional,
                                                 boolean batchFirst)
      • avgPool

        public static PtNDArray avgPool(PtNDArray ndArray,
                                        ai.djl.ndarray.types.Shape kernelSize,
                                        ai.djl.ndarray.types.Shape stride,
                                        ai.djl.ndarray.types.Shape padding,
                                        boolean ceilMode,
                                        boolean countIncludePad)
      • maxPool

        public static PtNDArray maxPool(PtNDArray ndArray,
                                        ai.djl.ndarray.types.Shape kernelSize,
                                        ai.djl.ndarray.types.Shape stride,
                                        ai.djl.ndarray.types.Shape padding,
                                        boolean ceilMode)
      • adaptiveMaxPool

        public static PtNDArray adaptiveMaxPool(PtNDArray ndArray,
                                                ai.djl.ndarray.types.Shape outputSize)
      • adaptiveAvgPool

        public static PtNDArray adaptiveAvgPool(PtNDArray ndArray,
                                                ai.djl.ndarray.types.Shape outputSize)
      • lpPool

        public static PtNDArray lpPool(PtNDArray ndArray,
                                       double normType,
                                       ai.djl.ndarray.types.Shape kernelSize,
                                       ai.djl.ndarray.types.Shape stride,
                                       boolean ceilMode)
      • getDataType

        public static ai.djl.ndarray.types.DataType getDataType(PtNDArray ndArray)
      • getDevice

        public static ai.djl.Device getDevice(PtNDArray ndArray)
      • getSparseFormat

        public static ai.djl.ndarray.types.SparseFormat getSparseFormat(PtNDArray ndArray)
      • getShape

        public static ai.djl.ndarray.types.Shape getShape(PtNDArray ndArray)
      • getByteBuffer

        public static java.nio.ByteBuffer getByteBuffer(PtNDArray ndArray)
      • deleteNDArray

        public static void deleteNDArray(long handle)
      • requiresGrad

        public static boolean requiresGrad(PtNDArray ndArray)
      • getGradientFunctionNames

        public static java.lang.String getGradientFunctionNames(PtNDArray ndArray)
      • attachGradient

        public static void attachGradient(PtNDArray ndArray,
                                          boolean requiresGrad)
      • backward

        public static void backward(PtNDArray ndArray,
                                    PtNDArray gradNd,
                                    boolean keepGraph,
                                    boolean createGraph)
      • deleteModule

        public static void deleteModule(long pointer)
      • setGraphExecutorOptimize

        public static void setGraphExecutorOptimize(boolean enabled)
      • loadModule

        public static PtSymbolBlock loadModule(PtNDManager manager,
                                               java.nio.file.Path path,
                                               boolean mapLocation,
                                               java.lang.String[] extraFileKeys,
                                               java.lang.String[] extraFileValues,
                                               boolean trainParam)
      • loadModule

        public static PtSymbolBlock loadModule(PtNDManager manager,
                                               java.io.InputStream is,
                                               boolean mapLocation,
                                               boolean hasSize)
                                        throws java.io.IOException
        Throws:
        java.io.IOException
      • loadModuleHandle

        public static long loadModuleHandle(java.io.InputStream is,
                                            ai.djl.Device device,
                                            boolean mapLocation,
                                            boolean hasSize)
                                     throws java.io.IOException
        Throws:
        java.io.IOException
      • writeModule

        public static void writeModule(PtSymbolBlock block,
                                       java.io.OutputStream os,
                                       boolean writeSize)
      • getMethodNames

        public static java.lang.String[] getMethodNames(PtSymbolBlock block)
      • enableInferenceMode

        public static void enableInferenceMode(PtSymbolBlock block)
      • enableTrainingMode

        public static void enableTrainingMode(PtSymbolBlock block)
      • zeroGrad

        public static void zeroGrad(PtNDArray weight)
      • adamUpdate

        public static void adamUpdate(PtNDArray weight,
                                      PtNDArray grad,
                                      PtNDArray mean,
                                      PtNDArray variance,
                                      float lr,
                                      float learningRateBiasCorrection,
                                      float wd,
                                      float rescaleGrad,
                                      float clipGrad,
                                      float beta1,
                                      float beta2,
                                      float eps,
                                      boolean adamw)
      • sgdUpdate

        public static void sgdUpdate(PtNDArray weight,
                                     PtNDArray grad,
                                     PtNDArray state,
                                     float lr,
                                     float wd,
                                     float rescaleGrad,
                                     float clipGrad,
                                     float momentum)
      • getLayout

        public static int getLayout(PtNDArray array)
      • norm

        public static PtNDArray norm(PtNDArray ndArray,
                                     int ord,
                                     int[] axes,
                                     boolean keepDims)