Class JniUtils

java.lang.Object
ai.djl.pytorch.jni.JniUtils

public final class JniUtils extends Object
A class containing utilities to interact with the PyTorch Engine's Java Native Interface (JNI) layer.
  • Method Details

    • isGradMode

      public static boolean isGradMode()
    • setGradMode

      public static void setGradMode(boolean enable)
    • getNumInteropThreads

      public static int getNumInteropThreads()
    • getNumThreads

      public static int getNumThreads()
    • setNumInteropThreads

      public static void setNumInteropThreads(int threads)
    • setNumThreads

      public static void setNumThreads(int threads)
    • setBenchmarkCuDNN

      public static void setBenchmarkCuDNN(boolean enable)
    • getFeatures

      public static Set<String> getFeatures()
    • setSeed

      public static void setSeed(long seed)
    • startProfile

      public static void startProfile(boolean useCuda, boolean recordShape, boolean profileMemory)
      Call this method to start profiling the code region you are interested in; call stopProfile afterwards to end profiling.

      Example usage

            JniUtils.startProfile(false, true, true);
            predictor.predict(img);
            JniUtils.stopProfile(outputFile);
       
      Parameters:
      useCuda - Whether to also time CUDA events, using the cudaEvent API.
      recordShape - Whether to record shapes; if enabled, information about input dimensions is collected.
      profileMemory - Whether to report memory usage.
    • stopProfile

      public static void stopProfile(String outputFile)
    • createNdFromByteBuffer

      public static PtNDArray createNdFromByteBuffer(PtNDManager manager, ByteBuffer data, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.ndarray.types.SparseFormat fmt, ai.djl.Device device)
    • emptyCudaCache

      public static void emptyCudaCache()
    • createEmptyNdArray

      public static PtNDArray createEmptyNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • createZerosNdArray

      public static PtNDArray createZerosNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • createOnesNdArray

      public static PtNDArray createOnesNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • full

      public static PtNDArray full(PtNDManager manager, ai.djl.ndarray.types.Shape shape, double fillValue, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • zerosLike

      public static PtNDArray zerosLike(PtNDArray array, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • onesLike

      public static PtNDArray onesLike(PtNDArray array, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • arange

      public static PtNDArray arange(PtNDManager manager, float start, float stop, float step, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • linspace

      public static PtNDArray linspace(PtNDManager manager, float start, float stop, int step, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • createSparseCoo

      public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, ai.djl.ndarray.types.Shape shape)
    • to

      public static PtNDArray to(PtNDArray ndArray, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
    • toSparse

      public static PtNDArray toSparse(PtNDArray ndArray)
    • toDense

      public static PtNDArray toDense(PtNDArray ndArray)
    • broadcast

      public static PtNDArray broadcast(PtNDArray ndArray, ai.djl.ndarray.types.Shape shape)
    • slice

      public static PtNDArray slice(PtNDArray ndArray, long dim, long start, long stop, long step)
    • index

      public static PtNDArray index(PtNDArray ndArray, long[] minIndices, long[] maxIndices, long[] stepIndices, PtNDManager manager)
    • indexAdv

      public static PtNDArray indexAdv(PtNDArray ndArray, ai.djl.ndarray.index.NDIndex index, PtNDManager manager)
    • indexAdvPut

      public static void indexAdvPut(PtNDArray ndArray, ai.djl.ndarray.index.NDIndex index, PtNDArray data)
    • indexSet

      public static void indexSet(PtNDArray ndArray, PtNDArray value, long[] minIndices, long[] maxIndices, long[] stepIndices)
    • set

      public static void set(PtNDArray self, ByteBuffer data)
    • gather

      public static PtNDArray gather(PtNDArray ndArray, PtNDArray index, long dim)
    • take

      public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager manager)
    • put

      public static PtNDArray put(PtNDArray ndArray, PtNDArray index, PtNDArray value)
    • scatter

      public static PtNDArray scatter(PtNDArray ndArray, PtNDArray index, PtNDArray value, int axis)
    • pick

      public static PtNDArray pick(PtNDArray ndArray, PtNDArray index, long dim)
    • where

      public static PtNDArray where(PtNDArray condition, PtNDArray self, PtNDArray other)
    • booleanMask

      public static PtNDArray booleanMask(PtNDArray ndArray, PtNDArray indicesNd)
    • booleanMaskSet

      public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray indicesNd)
    • getItem

      public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager manager)
    • clone

      public static PtNDArray clone(PtNDArray ndArray)
    • reshape

      public static PtNDArray reshape(PtNDArray ndArray, long[] shape)
    • stack

      public static PtNDArray stack(PtNDArray[] arrays, int dim)
    • cat

      public static PtNDArray cat(PtNDArray[] arrays, long dim)
    • tile

      public static PtNDArray tile(PtNDArray ndArray, long[] repeats)
    • repeat

      public static PtNDArray repeat(PtNDArray ndArray, long repeat, long dim)
    • softmax

      public static PtNDArray softmax(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dTpe)
    • logSoftmax

      public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dTpe)
    • argMax

      public static PtNDArray argMax(PtNDArray ndArray)
    • argMax

      public static PtNDArray argMax(PtNDArray ndArray, long dim, boolean keepDim)
    • topK

      public static ai.djl.ndarray.NDList topK(PtNDArray ndArray, long k, long axis, boolean largest, boolean sorted)
    • argMin

      public static PtNDArray argMin(PtNDArray ndArray)
    • argMin

      public static PtNDArray argMin(PtNDArray ndArray, long dim, boolean keepDim)
    • argSort

      public static PtNDArray argSort(PtNDArray ndArray, long dim, boolean keepDim)
    • sort

      public static PtNDArray sort(PtNDArray ndArray, long dim, boolean descending)
    • permute

      public static PtNDArray permute(PtNDArray ndArray, long[] dims)
    • flip

      public static PtNDArray flip(PtNDArray ndArray, long[] dims)
    • transpose

      public static PtNDArray transpose(PtNDArray ndArray, long dim1, long dim2)
    • contentEqual

      public static boolean contentEqual(PtNDArray ndArray1, PtNDArray ndArray2)
    • add

      public static PtNDArray add(PtNDArray ndArray1, PtNDArray ndArray2)
    • addi

      public static void addi(PtNDArray ndArray1, PtNDArray ndArray2)
    • sub

      public static PtNDArray sub(PtNDArray ndArray1, PtNDArray ndArray2)
    • subi

      public static void subi(PtNDArray ndArray1, PtNDArray ndArray2)
    • mul

      public static PtNDArray mul(PtNDArray ndArray1, PtNDArray ndArray2)
    • muli

      public static void muli(PtNDArray ndArray1, PtNDArray ndArray2)
    • div

      public static PtNDArray div(PtNDArray ndArray1, PtNDArray ndArray2)
    • divi

      public static void divi(PtNDArray ndArray1, PtNDArray ndArray2)
    • remainder

      public static PtNDArray remainder(PtNDArray ndArray1, PtNDArray ndArray2)
    • remainderi

      public static void remainderi(PtNDArray ndArray1, PtNDArray ndArray2)
    • pow

      public static PtNDArray pow(PtNDArray ndArray1, PtNDArray ndArray2)
    • powi

      public static void powi(PtNDArray ndArray1, PtNDArray ndArray2)
    • sign

      public static PtNDArray sign(PtNDArray ndArray)
    • signi

      public static void signi(PtNDArray ndArray)
    • logicalAnd

      public static PtNDArray logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2)
    • logicalOr

      public static PtNDArray logicalOr(PtNDArray ndArray1, PtNDArray ndArray2)
    • logicalXor

      public static PtNDArray logicalXor(PtNDArray ndArray1, PtNDArray ndArray2)
    • logicalNot

      public static PtNDArray logicalNot(PtNDArray ndArray)
    • matmul

      public static PtNDArray matmul(PtNDArray ndArray1, PtNDArray ndArray2)
    • bmm

      public static PtNDArray bmm(PtNDArray ndArray1, PtNDArray ndArray2)
    • xlogy

      public static PtNDArray xlogy(PtNDArray ndArray1, PtNDArray ndArray2)
    • dot

      public static PtNDArray dot(PtNDArray ndArray1, PtNDArray ndArray2)
    • max

      public static PtNDArray max(PtNDArray ndArray1, PtNDArray ndArray2)
    • max

      public static PtNDArray max(PtNDArray ndArray)
    • max

      public static PtNDArray max(PtNDArray ndArray, long dim, boolean keepDim)
    • min

      public static PtNDArray min(PtNDArray ndArray1, PtNDArray ndArray2)
    • min

      public static PtNDArray min(PtNDArray ndArray)
    • min

      public static PtNDArray min(PtNDArray ndArray, long dim, boolean keepDim)
    • median

      public static ai.djl.ndarray.NDList median(PtNDArray ndArray, long dim, boolean keepDim)
    • mean

      public static PtNDArray mean(PtNDArray ndArray)
    • mean

      public static PtNDArray mean(PtNDArray ndArray, long dim, boolean keepDim)
    • rot90

      public static PtNDArray rot90(PtNDArray ndArray, int times, int[] axes)
    • sum

      public static PtNDArray sum(PtNDArray ndArray)
    • sum

      public static PtNDArray sum(PtNDArray ndArray, long[] dims, boolean keepDim)
    • cumProd

      public static PtNDArray cumProd(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dataType)
    • prod

      public static PtNDArray prod(PtNDArray ndArray)
    • prod

      public static PtNDArray prod(PtNDArray ndArray, long dim, boolean keepDim)
    • cumSum

      public static PtNDArray cumSum(PtNDArray ndArray, long dim)
    • oneHot

      public static PtNDArray oneHot(PtNDArray ndArray, int depth, ai.djl.ndarray.types.DataType dataType)
    • split

      public static ai.djl.ndarray.NDList split(PtNDArray ndArray, long size, long axis)
    • split

      public static ai.djl.ndarray.NDList split(PtNDArray ndArray, long[] indices, long axis)
    • squeeze

      public static PtNDArray squeeze(PtNDArray ndArray)
    • squeeze

      public static PtNDArray squeeze(PtNDArray ndArray, long dim)
    • unsqueeze

      public static PtNDArray unsqueeze(PtNDArray ndArray, long dim)
    • unique

      public static ai.djl.ndarray.NDList unique(PtNDArray ndArray, Integer dim, boolean sorted, boolean returnInverse, boolean returnCounts)
    • flatten

      public static PtNDArray flatten(PtNDArray ndArray, long startDim, long endDim)
    • fft

      public static PtNDArray fft(PtNDArray ndArray, long length, long axis)
    • stft

      public static PtNDArray stft(PtNDArray ndArray, long nFft, long hopLength, PtNDArray window, boolean center, boolean normalize, boolean returnComplex)
    • fft2

      public static PtNDArray fft2(PtNDArray ndArray, long[] sizes, long[] axes)
    • ifft2

      public static PtNDArray ifft2(PtNDArray ndArray, long[] sizes, long[] axes)
    • real

      public static PtNDArray real(PtNDArray ndArray)
    • complex

      public static PtNDArray complex(PtNDArray ndArray)
    • abs

      public static PtNDArray abs(PtNDArray ndArray)
    • square

      public static PtNDArray square(PtNDArray ndArray)
    • floor

      public static PtNDArray floor(PtNDArray ndArray)
    • ceil

      public static PtNDArray ceil(PtNDArray ndArray)
    • round

      public static PtNDArray round(PtNDArray ndArray)
    • trunc

      public static PtNDArray trunc(PtNDArray ndArray)
    • clip

      public static PtNDArray clip(PtNDArray ndArray, Number min, Number max)
    • exp

      public static PtNDArray exp(PtNDArray ndArray)
    • log

      public static PtNDArray log(PtNDArray ndArray)
    • log10

      public static PtNDArray log10(PtNDArray ndArray)
    • log2

      public static PtNDArray log2(PtNDArray ndArray)
    • sin

      public static PtNDArray sin(PtNDArray ndArray)
    • cos

      public static PtNDArray cos(PtNDArray ndArray)
    • tan

      public static PtNDArray tan(PtNDArray ndArray)
    • asin

      public static PtNDArray asin(PtNDArray ndArray)
    • acos

      public static PtNDArray acos(PtNDArray ndArray)
    • atan

      public static PtNDArray atan(PtNDArray ndArray)
    • atan2

      public static PtNDArray atan2(PtNDArray self, PtNDArray other)
    • sqrt

      public static PtNDArray sqrt(PtNDArray ndArray)
    • sinh

      public static PtNDArray sinh(PtNDArray ndArray)
    • cosh

      public static PtNDArray cosh(PtNDArray ndArray)
    • tanh

      public static PtNDArray tanh(PtNDArray ndArray)
    • sigmoid

      public static PtNDArray sigmoid(PtNDArray ndArray)
    • all

      public static PtNDArray all(PtNDArray ndArray)
    • any

      public static PtNDArray any(PtNDArray ndArray)
    • none

      public static PtNDArray none(PtNDArray ndArray)
    • eq

      public static PtNDArray eq(PtNDArray self, PtNDArray other)
    • neq

      public static PtNDArray neq(PtNDArray self, PtNDArray other)
    • gt

      public static PtNDArray gt(PtNDArray self, PtNDArray other)
    • gte

      public static PtNDArray gte(PtNDArray self, PtNDArray other)
    • lt

      public static PtNDArray lt(PtNDArray self, PtNDArray other)
    • lte

      public static PtNDArray lte(PtNDArray self, PtNDArray other)
    • neg

      public static PtNDArray neg(PtNDArray ndArray)
    • negi

      public static void negi(PtNDArray ndArray)
    • isNaN

      public static PtNDArray isNaN(PtNDArray ndArray)
    • isInf

      public static PtNDArray isInf(PtNDArray ndArray)
    • randint

      public static PtNDArray randint(PtNDManager manager, long low, long high, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
    • randperm

      public static PtNDArray randperm(PtNDManager manager, long n, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
    • normal

      public static PtNDArray normal(PtNDManager manager, double mean, double std, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
    • uniform

      public static PtNDArray uniform(PtNDManager manager, double low, double high, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
    • eye

      public static PtNDArray eye(PtNDManager manager, int n, int m, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
    • hannWindow

      public static PtNDArray hannWindow(PtNDManager manager, long numPoints, boolean periodic, ai.djl.Device device)
    • erfinv

      public static PtNDArray erfinv(PtNDArray ndArray)
    • erf

      public static PtNDArray erf(PtNDArray ndArray)
    • inverse

      public static PtNDArray inverse(PtNDArray ndArray)
    • interpolate

      public static PtNDArray interpolate(PtNDArray ndArray, long[] size, int mode, boolean alignCorners)
    • linear

      public static PtNDArray linear(PtNDArray input, PtNDArray weight, PtNDArray bias)
    • embedding

      public static PtNDArray embedding(PtNDArray input, PtNDArray weight, boolean sparse)
    • relu

      public static PtNDArray relu(PtNDArray ndArray)
    • softPlus

      public static PtNDArray softPlus(PtNDArray ndArray)
    • softSign

      public static PtNDArray softSign(PtNDArray ndArray)
    • leakyRelu

      public static PtNDArray leakyRelu(PtNDArray ndArray, double negativeSlope)
    • elu

      public static PtNDArray elu(PtNDArray ndArray, double alpha)
    • selu

      public static PtNDArray selu(PtNDArray ndArray)
    • gelu

      public static PtNDArray gelu(PtNDArray ndArray)
    • convolution

      public static PtNDArray convolution(PtNDArray ndArray, PtNDArray weight, PtNDArray bias, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, ai.djl.ndarray.types.Shape dilation, int groups)
    • batchNorm

      public static PtNDArray batchNorm(PtNDArray ndArray, PtNDArray gamma, PtNDArray beta, PtNDArray runningMean, PtNDArray runningVar, boolean isTraining, double momentum, double eps)
    • layerNorm

      public static PtNDArray layerNorm(PtNDArray ndArray, ai.djl.ndarray.types.Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps)
    • normalize

      public static PtNDArray normalize(PtNDArray ndArray, double p, long dim, double eps)
    • dropout

      public static PtNDArray dropout(PtNDArray ndArray, double prob, boolean training)
    • rnn

      public static ai.djl.ndarray.NDList rnn(PtNDArray input, PtNDArray hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, ai.djl.nn.recurrent.RNN.Activation activation, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
    • gru

      public static ai.djl.ndarray.NDList gru(PtNDArray input, PtNDArray hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
    • lstm

      public static ai.djl.ndarray.NDList lstm(PtNDArray input, ai.djl.ndarray.NDList hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
    • avgPool

      public static PtNDArray avgPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, boolean ceilMode, boolean countIncludePad)
    • maxPool

      public static PtNDArray maxPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, boolean ceilMode)
    • adaptiveMaxPool

      public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape outputSize)
    • adaptiveAvgPool

      public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape outputSize)
    • lpPool

      public static PtNDArray lpPool(PtNDArray ndArray, double normType, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, boolean ceilMode)
    • getDataType

      public static ai.djl.ndarray.types.DataType getDataType(PtNDArray ndArray)
    • getDevice

      public static ai.djl.Device getDevice(PtNDArray ndArray)
    • getSparseFormat

      public static ai.djl.ndarray.types.SparseFormat getSparseFormat(PtNDArray ndArray)
    • getShape

      public static ai.djl.ndarray.types.Shape getShape(PtNDArray ndArray)
    • getByteBuffer

      public static ByteBuffer getByteBuffer(PtNDArray ndArray)
    • deleteNDArray

      public static void deleteNDArray(long handle)
    • requiresGrad

      public static boolean requiresGrad(PtNDArray ndArray)
    • getGradientFunctionNames

      public static String getGradientFunctionNames(PtNDArray ndArray)
    • attachGradient

      public static void attachGradient(PtNDArray ndArray, boolean requiresGrad)
    • detachGradient

      public static PtNDArray detachGradient(PtNDArray ndArray)
    • getGradient

      public static PtNDArray getGradient(PtNDArray ndArray)
    • backward

      public static void backward(PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph)
    • deleteModule

      public static void deleteModule(long pointer)
    • setGraphExecutorOptimize

      public static void setGraphExecutorOptimize(boolean enabled)
    • loadModule

      public static PtSymbolBlock loadModule(PtNDManager manager, Path path, boolean mapLocation, String[] extraFileKeys, String[] extraFileValues, boolean trainParam)
    • loadModule

      public static PtSymbolBlock loadModule(PtNDManager manager, InputStream is, boolean mapLocation, boolean hasSize) throws IOException
      Throws:
      IOException
    • loadModuleHandle

      public static long loadModuleHandle(InputStream is, ai.djl.Device device, boolean mapLocation, boolean hasSize) throws IOException
      Throws:
      IOException
    • writeModule

      public static void writeModule(PtSymbolBlock block, OutputStream os, boolean writeSize)
    • moduleGetParams

      public static ai.djl.ndarray.NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager)
    • getMethodNames

      public static String[] getMethodNames(PtSymbolBlock block)
    • enableInferenceMode

      public static void enableInferenceMode(PtSymbolBlock block)
    • enableTrainingMode

      public static void enableTrainingMode(PtSymbolBlock block)
    • zeroGrad

      public static void zeroGrad(PtNDArray weight)
    • adamUpdate

      public static void adamUpdate(PtNDArray weight, PtNDArray grad, PtNDArray mean, PtNDArray variance, float lr, float learningRateBiasCorrection, float wd, float rescaleGrad, float clipGrad, float beta1, float beta2, float eps, boolean adamw)
    • sgdUpdate

      public static void sgdUpdate(PtNDArray weight, PtNDArray grad, PtNDArray state, float lr, float wd, float rescaleGrad, float clipGrad, float momentum)
    • getLayout

      public static int getLayout(PtNDArray array)
    • norm

      public static PtNDArray norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims)
    • nonZeros

      public static PtNDArray nonZeros(PtNDArray ndArray)