Package ai.djl.pytorch.jni
Class JniUtils
- java.lang.Object
-
- ai.djl.pytorch.jni.JniUtils
-
public final class JniUtils extends java.lang.Object
A class containing utilities to interact with the PyTorch Engine's Java Native Interface (JNI) layer.
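Most callers reach these helpers indirectly through PtNDArray and PtNDManager, but they can also be invoked directly. A minimal sketch, assuming the PyTorch engine is the default engine so the base NDManager can be cast to PtNDManager:
    try (PtNDManager manager = (PtNDManager) ai.djl.ndarray.NDManager.newBaseManager()) {
        PtNDArray a = (PtNDArray) manager.create(new float[] {1f, 2f, 3f});
        PtNDArray b = JniUtils.exp(a);                                 // element-wise exp through the JNI layer
        ai.djl.ndarray.types.Shape shape = JniUtils.getShape(b);       // (3)
        ai.djl.ndarray.types.DataType type = JniUtils.getDataType(b);  // FLOAT32
    }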
-
-
Method Summary
Modifier and Type / Method / Description
static PtNDArray
abs(PtNDArray ndArray)
static PtNDArray
acos(PtNDArray ndArray)
static void
adamUpdate(PtNDArray weight, PtNDArray grad, PtNDArray mean, PtNDArray variance, float lr, float wd, float rescaleGrad, float clipGrad, float beta1, float beta2, float eps)
static PtNDArray
adaptiveAvgPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape outputSize)
static PtNDArray
adaptiveMaxPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape outputSize)
static PtNDArray
add(PtNDArray ndArray1, PtNDArray ndArray2)
static void
addi(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
all(PtNDArray ndArray)
static PtNDArray
any(PtNDArray ndArray)
static PtNDArray
arange(PtNDManager manager, float start, float stop, float step, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
argMax(PtNDArray ndArray)
static PtNDArray
argMax(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
argMin(PtNDArray ndArray)
static PtNDArray
argMin(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
argSort(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
asin(PtNDArray ndArray)
static PtNDArray
atan(PtNDArray ndArray)
static void
attachGradient(PtNDArray ndArray, boolean requiresGrad)
static PtNDArray
avgPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, boolean ceilMode, boolean countIncludePad)
static void
backward(PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph)
static PtNDArray
batchNorm(PtNDArray ndArray, PtNDArray gamma, PtNDArray beta, PtNDArray runningMean, PtNDArray runningVar, boolean isTraining, double momentum, double eps)
static PtNDArray
booleanMask(PtNDArray ndArray, PtNDArray indicesNd)
static void
booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray indicesNd)
static PtNDArray
broadcast(PtNDArray ndArray, ai.djl.ndarray.types.Shape shape)
static PtNDArray
cat(PtNDArray[] arrays, long dim)
static PtNDArray
ceil(PtNDArray ndArray)
static PtNDArray
clip(PtNDArray ndArray, java.lang.Number min, java.lang.Number max)
static PtNDArray
clone(PtNDArray ndArray)
static boolean
contentEqual(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
convolution(PtNDArray ndArray, PtNDArray weight, PtNDArray bias, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, ai.djl.ndarray.types.Shape dilation, int groups)
static PtNDArray
cos(PtNDArray ndArray)
static PtNDArray
cosh(PtNDArray ndArray)
static PtNDArray
createEmptyNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
createNdFromByteBuffer(PtNDManager manager, java.nio.ByteBuffer data, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.ndarray.types.SparseFormat fmt, ai.djl.Device device)
static PtNDArray
createOnesNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
createSparseCoo(PtNDArray indices, PtNDArray values, ai.djl.ndarray.types.Shape shape)
static PtNDArray
createZerosNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
cumSum(PtNDArray ndArray, long dim)
static void
deleteModule(long pointer)
static void
deleteNDArray(long handle)
static PtNDArray
detachGradient(PtNDArray ndArray)
static PtNDArray
div(PtNDArray ndArray1, PtNDArray ndArray2)
static void
divi(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
dot(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
dropout(PtNDArray ndArray, double prob, boolean training)
static PtNDArray
elu(PtNDArray ndArray, double alpha)
static PtNDArray
embedding(PtNDArray input, PtNDArray weight, boolean sparse)
static void
enableInferenceMode(PtSymbolBlock block)
static void
enableTrainingMode(PtSymbolBlock block)
static PtNDArray
eq(PtNDArray self, PtNDArray other)
static PtNDArray
erfinv(PtNDArray ndArray)
static PtNDArray
exp(PtNDArray ndArray)
static PtNDArray
eye(PtNDManager manager, int n, int m, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
flatten(PtNDArray ndArray, long startDim, long endDim)
static PtNDArray
flip(PtNDArray ndArray, long[] dims)
static PtNDArray
floor(PtNDArray ndArray)
static PtNDArray
full(PtNDManager manager, ai.djl.ndarray.types.Shape shape, double fillValue, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
gather(PtNDArray ndArray, PtNDArray index, long dim)
static PtNDArray
gelu(PtNDArray ndArray)
static java.nio.ByteBuffer
getByteBuffer(PtNDArray ndArray)
static ai.djl.ndarray.types.DataType
getDataType(PtNDArray ndArray)
static ai.djl.Device
getDevice(PtNDArray ndArray)
static java.util.Set<java.lang.String>
getFeatures()
static PtNDArray
getGradient(PtNDArray ndArray)
static java.lang.String
getGradientFunctionNames(PtNDArray ndArray)
static PtNDArray
getItem(PtNDArray ndArray, long[] indices, PtNDManager manager)
static int
getLayout(PtNDArray array)
static int
getNumInteropThreads()
static int
getNumThreads()
static ai.djl.ndarray.types.Shape
getShape(PtNDArray ndArray)
static ai.djl.ndarray.types.SparseFormat
getSparseFormat(PtNDArray ndArray)
static ai.djl.ndarray.NDList
gru(PtNDArray input, PtNDArray hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
static PtNDArray
gt(PtNDArray self, PtNDArray other)
static PtNDArray
gte(PtNDArray self, PtNDArray other)
static PtNDArray
index(PtNDArray ndArray, long[] minIndices, long[] maxIndices, long[] stepIndices, PtNDManager manager)
static PtNDArray
indexAdv(PtNDArray ndArray, ai.djl.ndarray.index.NDIndex index, PtNDManager manager)
static void
indexAdvPut(PtNDArray ndArray, ai.djl.ndarray.index.NDIndex index, PtNDArray data)
static void
indexSet(PtNDArray ndArray, PtNDArray value, long[] minIndices, long[] maxIndices, long[] stepIndices)
static PtNDArray
interpolate(PtNDArray ndArray, long[] size, int mode, boolean alignCorners)
static PtNDArray
inverse(PtNDArray ndArray)
static boolean
isGradMode()
static PtNDArray
isInf(PtNDArray ndArray)
static PtNDArray
isNaN(PtNDArray ndArray)
static PtNDArray
layerNorm(PtNDArray ndArray, ai.djl.ndarray.types.Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps)
static PtNDArray
leakyRelu(PtNDArray ndArray, double negativeSlope)
static PtNDArray
linear(PtNDArray input, PtNDArray weight, PtNDArray bias)
static PtNDArray
linspace(PtNDManager manager, float start, float stop, int step, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtSymbolBlock
loadModule(PtNDManager manager, java.io.InputStream is, boolean mapLocation, boolean hasSize)
static PtSymbolBlock
loadModule(PtNDManager manager, java.nio.file.Path path, boolean mapLocation, java.lang.String[] extraFileKeys, java.lang.String[] extraFileValues)
static long
loadModuleHandle(java.io.InputStream is, ai.djl.Device device, boolean mapLocation, boolean hasSize)
static PtNDArray
log(PtNDArray ndArray)
static PtNDArray
log10(PtNDArray ndArray)
static PtNDArray
log2(PtNDArray ndArray)
static PtNDArray
logicalAnd(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
logicalNot(PtNDArray ndArray)
static PtNDArray
logicalOr(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
logicalXor(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
logSoftmax(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dTpe)
static PtNDArray
lpPool(PtNDArray ndArray, double normType, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, boolean ceilMode)
static ai.djl.ndarray.NDList
lstm(PtNDArray input, ai.djl.ndarray.NDList hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
static PtNDArray
lt(PtNDArray self, PtNDArray other)
static PtNDArray
lte(PtNDArray self, PtNDArray other)
static PtNDArray
matmul(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
max(PtNDArray ndArray)
static PtNDArray
max(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
max(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
maxPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, boolean ceilMode)
static PtNDArray
mean(PtNDArray ndArray)
static PtNDArray
mean(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
min(PtNDArray ndArray)
static PtNDArray
min(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
min(PtNDArray ndArray1, PtNDArray ndArray2)
static ai.djl.ndarray.NDList
moduleGetParams(PtSymbolBlock block, PtNDManager manager)
static PtNDArray
mul(PtNDArray ndArray1, PtNDArray ndArray2)
static void
muli(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
neg(PtNDArray ndArray)
static void
negi(PtNDArray ndArray)
static PtNDArray
neq(PtNDArray self, PtNDArray other)
static PtNDArray
none(PtNDArray ndArray)
static PtNDArray
nonZeros(PtNDArray ndArray)
static PtNDArray
norm(PtNDArray ndArray, int ord, int[] axes, boolean keepDims)
static PtNDArray
normal(PtNDManager manager, double mean, double std, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
static PtNDArray
normalize(PtNDArray ndArray, double p, long dim, double eps)
static PtNDArray
oneHot(PtNDArray ndArray, int depth, ai.djl.ndarray.types.DataType dataType)
static PtNDArray
onesLike(PtNDArray array, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
static PtNDArray
permute(PtNDArray ndArray, long[] dims)
static PtNDArray
pick(PtNDArray ndArray, PtNDArray index, long dim)
static PtNDArray
pow(PtNDArray ndArray1, PtNDArray ndArray2)
static void
powi(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
prod(PtNDArray ndArray)
static PtNDArray
prod(PtNDArray ndArray, long dim, boolean keepDim)
static PtNDArray
put(PtNDArray ndArray, PtNDArray index, PtNDArray data)
static PtNDArray
randint(PtNDManager manager, long low, long high, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
static PtNDArray
relu(PtNDArray ndArray)
static PtNDArray
remainder(PtNDArray ndArray1, PtNDArray ndArray2)
static void
remainderi(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
repeat(PtNDArray ndArray, long repeat, long dim)
static boolean
requiresGrad(PtNDArray ndArray)
static PtNDArray
reshape(PtNDArray ndArray, long[] shape)
static ai.djl.ndarray.NDList
rnn(PtNDArray input, PtNDArray hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, ai.djl.nn.recurrent.RNN.Activation activation, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
static PtNDArray
rot90(PtNDArray ndArray, int times, int[] axes)
static PtNDArray
round(PtNDArray ndArray)
static PtNDArray
selu(PtNDArray ndArray)
static void
set(PtNDArray self, java.nio.ByteBuffer data)
static void
setBenchmarkCuDNN(boolean enable)
static void
setGradMode(boolean enable)
static void
setGraphExecutorOptimize(boolean enabled)
static void
setNumInteropThreads(int threads)
static void
setNumThreads(int threads)
static void
setSeed(long seed)
static void
sgdUpdate(PtNDArray weight, PtNDArray grad, PtNDArray state, float lr, float wd, float rescaleGrad, float clipGrad, float momentum)
static PtNDArray
sigmoid(PtNDArray ndArray)
static PtNDArray
sign(PtNDArray ndArray)
static void
signi(PtNDArray ndArray)
static PtNDArray
sin(PtNDArray ndArray)
static PtNDArray
sinh(PtNDArray ndArray)
static PtNDArray
slice(PtNDArray ndArray, long dim, long start, long stop, long step)
static PtNDArray
softmax(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dTpe)
static PtNDArray
softPlus(PtNDArray ndArray)
static PtNDArray
softSign(PtNDArray ndArray)
static PtNDArray
sort(PtNDArray ndArray, long dim, boolean descending)
static ai.djl.ndarray.NDList
split(PtNDArray ndArray, long[] indices, long axis)
static ai.djl.ndarray.NDList
split(PtNDArray ndArray, long size, long axis)
static PtNDArray
sqrt(PtNDArray ndArray)
static PtNDArray
square(PtNDArray ndArray)
static PtNDArray
squeeze(PtNDArray ndArray)
static PtNDArray
squeeze(PtNDArray ndArray, long dim)
static PtNDArray
stack(PtNDArray[] arrays, int dim)
static void
startProfile(boolean useCuda, boolean recordShape, boolean profileMemory)
Call this method to start profiling the area you are interested in.
static void
stopProfile(java.lang.String outputFile)
static PtNDArray
sub(PtNDArray ndArray1, PtNDArray ndArray2)
static void
subi(PtNDArray ndArray1, PtNDArray ndArray2)
static PtNDArray
sum(PtNDArray ndArray)
static PtNDArray
sum(PtNDArray ndArray, long[] dims, boolean keepDim)
static PtNDArray
take(PtNDArray ndArray, PtNDArray index, PtNDManager manager)
static PtNDArray
tan(PtNDArray ndArray)
static PtNDArray
tanh(PtNDArray ndArray)
static PtNDArray
tile(PtNDArray ndArray, long[] repeats)
static PtNDArray
to(PtNDArray ndArray, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
static PtNDArray
toDense(PtNDArray ndArray)
static PtNDArray
toSparse(PtNDArray ndArray)
static PtNDArray
transpose(PtNDArray ndArray, long dim1, long dim2)
static PtNDArray
trunc(PtNDArray ndArray)
static PtNDArray
uniform(PtNDManager manager, double low, double high, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
static PtNDArray
unsqueeze(PtNDArray ndArray, long dim)
static PtNDArray
where(PtNDArray condition, PtNDArray self, PtNDArray other)
static void
writeModule(PtSymbolBlock block, java.io.OutputStream os, boolean writeSize)
static void
zeroGrad(PtNDArray weight)
static PtNDArray
zerosLike(PtNDArray array, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
-
-
Method Detail
-
isGradMode
public static boolean isGradMode()
-
setGradMode
public static void setGradMode(boolean enable)
-
getNumInteropThreads
public static int getNumInteropThreads()
-
getNumThreads
public static int getNumThreads()
-
setNumInteropThreads
public static void setNumInteropThreads(int threads)
-
setNumThreads
public static void setNumThreads(int threads)
-
setBenchmarkCuDNN
public static void setBenchmarkCuDNN(boolean enable)
-
getFeatures
public static java.util.Set<java.lang.String> getFeatures()
-
setSeed
public static void setSeed(long seed)
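A hedged configuration sketch combining the thread, cuDNN, and seed setters above; the values are illustrative only:
    JniUtils.setNumThreads(4);          // intra-op parallelism
    JniUtils.setNumInteropThreads(2);   // inter-op parallelism; set before any parallel work starts
    JniUtils.setBenchmarkCuDNN(true);   // let cuDNN auto-tune convolution algorithms
    JniUtils.setSeed(42);               // make random number generation reproducible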
-
startProfile
public static void startProfile(boolean useCuda, boolean recordShape, boolean profileMemory)
Call this method to start profiling the area you are interested in. Example usage:
JniUtils.startProfile(false, true, true);
Predictor.predict(img);
JniUtils.stopProfile(outputFile);
- Parameters:
useCuda
- Enables timing of CUDA events as well, using the cudaEvent API.
recordShape
- If shape recording is enabled, information about input dimensions will be collected.
profileMemory
- Whether to report memory usage.
-
stopProfile
public static void stopProfile(java.lang.String outputFile)
-
createNdFromByteBuffer
public static PtNDArray createNdFromByteBuffer(PtNDManager manager, java.nio.ByteBuffer data, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.ndarray.types.SparseFormat fmt, ai.djl.Device device)
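A minimal sketch of one way this factory might be called, assuming `manager` is a PtNDManager and the data sits in a direct, native-ordered buffer holding a dense 2x2 FLOAT32 tensor for the CPU:
    java.nio.ByteBuffer data =
            java.nio.ByteBuffer.allocateDirect(4 * 4).order(java.nio.ByteOrder.nativeOrder());
    data.asFloatBuffer().put(new float[] {1f, 2f, 3f, 4f});
    PtNDArray array =
            JniUtils.createNdFromByteBuffer(
                    manager,                                  // a PtNDManager
                    data,
                    new ai.djl.ndarray.types.Shape(2, 2),
                    ai.djl.ndarray.types.DataType.FLOAT32,
                    ai.djl.ndarray.types.SparseFormat.DENSE,
                    ai.djl.Device.cpu());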
-
createEmptyNdArray
public static PtNDArray createEmptyNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
createZerosNdArray
public static PtNDArray createZerosNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
createOnesNdArray
public static PtNDArray createOnesNdArray(PtNDManager manager, ai.djl.ndarray.types.Shape shape, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
full
public static PtNDArray full(PtNDManager manager, ai.djl.ndarray.types.Shape shape, double fillValue, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
zerosLike
public static PtNDArray zerosLike(PtNDArray array, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
onesLike
public static PtNDArray onesLike(PtNDArray array, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
arange
public static PtNDArray arange(PtNDManager manager, float start, float stop, float step, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
linspace
public static PtNDArray linspace(PtNDManager manager, float start, float stop, int step, ai.djl.ndarray.types.DataType dType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
createSparseCoo
public static PtNDArray createSparseCoo(PtNDArray indices, PtNDArray values, ai.djl.ndarray.types.Shape shape)
-
to
public static PtNDArray to(PtNDArray ndArray, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
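For example, a hedged sketch that moves an array to half precision on the first GPU, assuming a CUDA-enabled build and that `array` is a PtNDArray:
    PtNDArray converted =
            JniUtils.to(array, ai.djl.ndarray.types.DataType.FLOAT16, ai.djl.Device.gpu());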
-
index
public static PtNDArray index(PtNDArray ndArray, long[] minIndices, long[] maxIndices, long[] stepIndices, PtNDManager manager)
-
indexAdv
public static PtNDArray indexAdv(PtNDArray ndArray, ai.djl.ndarray.index.NDIndex index, PtNDManager manager)
-
indexAdvPut
public static void indexAdvPut(PtNDArray ndArray, ai.djl.ndarray.index.NDIndex index, PtNDArray data)
-
indexSet
public static void indexSet(PtNDArray ndArray, PtNDArray value, long[] minIndices, long[] maxIndices, long[] stepIndices)
-
set
public static void set(PtNDArray self, java.nio.ByteBuffer data)
-
take
public static PtNDArray take(PtNDArray ndArray, PtNDArray index, PtNDManager manager)
-
booleanMaskSet
public static void booleanMaskSet(PtNDArray ndArray, PtNDArray value, PtNDArray indicesNd)
-
getItem
public static PtNDArray getItem(PtNDArray ndArray, long[] indices, PtNDManager manager)
-
softmax
public static PtNDArray softmax(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dTpe)
-
logSoftmax
public static PtNDArray logSoftmax(PtNDArray ndArray, long dim, ai.djl.ndarray.types.DataType dTpe)
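A short sketch of softmax and logSoftmax on a 1-D array of logits; the cast assumes the array was created by a PtNDManager:
    PtNDArray logits = (PtNDArray) manager.create(new float[] {1f, 2f, 3f});
    PtNDArray probs = JniUtils.softmax(logits, 0, ai.djl.ndarray.types.DataType.FLOAT32);
    PtNDArray logProbs = JniUtils.logSoftmax(logits, 0, ai.djl.ndarray.types.DataType.FLOAT32);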
-
signi
public static void signi(PtNDArray ndArray)
-
oneHot
public static PtNDArray oneHot(PtNDArray ndArray, int depth, ai.djl.ndarray.types.DataType dataType)
-
split
public static ai.djl.ndarray.NDList split(PtNDArray ndArray, long size, long axis)
-
split
public static ai.djl.ndarray.NDList split(PtNDArray ndArray, long[] indices, long axis)
-
negi
public static void negi(PtNDArray ndArray)
-
randint
public static PtNDArray randint(PtNDManager manager, long low, long high, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
-
normal
public static PtNDArray normal(PtNDManager manager, double mean, double std, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
-
uniform
public static PtNDArray uniform(PtNDManager manager, double low, double high, ai.djl.ndarray.types.Shape size, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device)
-
eye
public static PtNDArray eye(PtNDManager manager, int n, int m, ai.djl.ndarray.types.DataType dataType, ai.djl.Device device, ai.djl.ndarray.types.SparseFormat fmt)
-
interpolate
public static PtNDArray interpolate(PtNDArray ndArray, long[] size, int mode, boolean alignCorners)
-
convolution
public static PtNDArray convolution(PtNDArray ndArray, PtNDArray weight, PtNDArray bias, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, ai.djl.ndarray.types.Shape dilation, int groups)
-
batchNorm
public static PtNDArray batchNorm(PtNDArray ndArray, PtNDArray gamma, PtNDArray beta, PtNDArray runningMean, PtNDArray runningVar, boolean isTraining, double momentum, double eps)
-
layerNorm
public static PtNDArray layerNorm(PtNDArray ndArray, ai.djl.ndarray.types.Shape normalizedShape, PtNDArray gamma, PtNDArray beta, double eps)
-
rnn
public static ai.djl.ndarray.NDList rnn(PtNDArray input, PtNDArray hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, ai.djl.nn.recurrent.RNN.Activation activation, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
-
gru
public static ai.djl.ndarray.NDList gru(PtNDArray input, PtNDArray hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
-
lstm
public static ai.djl.ndarray.NDList lstm(PtNDArray input, ai.djl.ndarray.NDList hx, ai.djl.ndarray.NDList params, boolean hasBiases, int numLayers, double dropRate, boolean training, boolean bidirectional, boolean batchFirst)
-
avgPool
public static PtNDArray avgPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, boolean ceilMode, boolean countIncludePad)
-
maxPool
public static PtNDArray maxPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, ai.djl.ndarray.types.Shape padding, boolean ceilMode)
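A hedged sketch of 2x2 max pooling with stride 2 and no padding, assuming `input` is a 4-D (NCHW) PtNDArray:
    PtNDArray pooled =
            JniUtils.maxPool(
                    input,
                    new ai.djl.ndarray.types.Shape(2, 2),   // kernel size
                    new ai.djl.ndarray.types.Shape(2, 2),   // stride
                    new ai.djl.ndarray.types.Shape(0, 0),   // padding
                    false);                                 // ceilMode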
-
adaptiveMaxPool
public static PtNDArray adaptiveMaxPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape outputSize)
-
adaptiveAvgPool
public static PtNDArray adaptiveAvgPool(PtNDArray ndArray, ai.djl.ndarray.types.Shape outputSize)
-
lpPool
public static PtNDArray lpPool(PtNDArray ndArray, double normType, ai.djl.ndarray.types.Shape kernelSize, ai.djl.ndarray.types.Shape stride, boolean ceilMode)
-
getDataType
public static ai.djl.ndarray.types.DataType getDataType(PtNDArray ndArray)
-
getDevice
public static ai.djl.Device getDevice(PtNDArray ndArray)
-
getSparseFormat
public static ai.djl.ndarray.types.SparseFormat getSparseFormat(PtNDArray ndArray)
-
getShape
public static ai.djl.ndarray.types.Shape getShape(PtNDArray ndArray)
-
getByteBuffer
public static java.nio.ByteBuffer getByteBuffer(PtNDArray ndArray)
-
deleteNDArray
public static void deleteNDArray(long handle)
-
requiresGrad
public static boolean requiresGrad(PtNDArray ndArray)
-
getGradientFunctionNames
public static java.lang.String getGradientFunctionNames(PtNDArray ndArray)
-
attachGradient
public static void attachGradient(PtNDArray ndArray, boolean requiresGrad)
-
backward
public static void backward(PtNDArray ndArray, PtNDArray gradNd, boolean keepGraph, boolean createGraph)
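A small autograd sketch, assuming `manager` is a PtNDManager; it computes y = x * x and reads d(y)/dx back from the leaf tensor:
    PtNDArray x = (PtNDArray) manager.create(new float[] {2f, 3f});
    JniUtils.attachGradient(x, true);                       // mark x as requiring gradients
    PtNDArray y = JniUtils.mul(x, x);                       // y = x * x
    PtNDArray grad = (PtNDArray) manager.ones(new ai.djl.ndarray.types.Shape(2));
    JniUtils.backward(y, grad, false, false);               // accumulate d(y)/dx on x
    PtNDArray dx = JniUtils.getGradient(x);                 // expected values: [4, 6]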
-
deleteModule
public static void deleteModule(long pointer)
-
setGraphExecutorOptimize
public static void setGraphExecutorOptimize(boolean enabled)
-
loadModule
public static PtSymbolBlock loadModule(PtNDManager manager, java.nio.file.Path path, boolean mapLocation, java.lang.String[] extraFileKeys, java.lang.String[] extraFileValues)
-
loadModule
public static PtSymbolBlock loadModule(PtNDManager manager, java.io.InputStream is, boolean mapLocation, boolean hasSize) throws java.io.IOException
- Throws:
java.io.IOException
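A hedged sketch of loading a traced TorchScript file from disk with the Path overload; the file name is hypothetical, and mapLocation=true is typically used to remap saved weights onto the manager's device:
    java.nio.file.Path modelFile = java.nio.file.Paths.get("traced_model.pt");  // hypothetical path
    PtSymbolBlock block =
            JniUtils.loadModule(manager, modelFile, true, new String[0], new String[0]);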
-
loadModuleHandle
public static long loadModuleHandle(java.io.InputStream is, ai.djl.Device device, boolean mapLocation, boolean hasSize) throws java.io.IOException
- Throws:
java.io.IOException
-
writeModule
public static void writeModule(PtSymbolBlock block, java.io.OutputStream os, boolean writeSize)
-
moduleGetParams
public static ai.djl.ndarray.NDList moduleGetParams(PtSymbolBlock block, PtNDManager manager)
-
enableInferenceMode
public static void enableInferenceMode(PtSymbolBlock block)
-
enableTrainingMode
public static void enableTrainingMode(PtSymbolBlock block)
-
zeroGrad
public static void zeroGrad(PtNDArray weight)
-
adamUpdate
public static void adamUpdate(PtNDArray weight, PtNDArray grad, PtNDArray mean, PtNDArray variance, float lr, float wd, float rescaleGrad, float clipGrad, float beta1, float beta2, float eps)
-
sgdUpdate
public static void sgdUpdate(PtNDArray weight, PtNDArray grad, PtNDArray state, float lr, float wd, float rescaleGrad, float clipGrad, float momentum)
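A hedged sketch of one in-place SGD-with-momentum step; the hyperparameter values are illustrative only, and `weight`, `grad`, and `state` are assumed to be PtNDArrays of matching shape:
    JniUtils.sgdUpdate(
            weight, grad, state,
            0.01f,   // lr
            0f,      // wd (weight decay)
            1f,      // rescaleGrad
            0f,      // clipGrad
            0.9f);   // momentum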
-
getLayout
public static int getLayout(PtNDArray array)
-
-