Package ai.djl.training
Interface GradientCollector
All Superinterfaces:
java.lang.AutoCloseable
public interface GradientCollector extends java.lang.AutoCloseable
An interface that provides a mechanism to collect gradients during training.
The GradientCollector should be opened with a try-with-resources statement. All operations performed within the try-with-resources block are recorded and the variables are marked. When the backward function is called, gradients are collected with respect to the previously marked variables.
The typical behavior is to open a gradient collector at the start of each batch and close it at the end of the batch. In this way, the gradient is reset between batches. If the gradient collector is left open across multiple calls to backward, the gradients collected are accumulated and added together.
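The per-batch lifecycle can be illustrated with a small, stdlib-only Java model. This is a hypothetical sketch, not DJL's implementation: the classes ToyVariable and ToyCollector are invented for illustration, a single marked scalar x is tracked for the operation y = x * x (so dy/dx = 2x), and the toy resets gradients when a collector is opened, which is only one assumed way of achieving the documented reset between batches.

```java
// Hypothetical, stdlib-only model of the gradient lifecycle; not DJL code.
final class ToyGradientLifecycle {

    /** A marked scalar variable that holds its accumulated gradient. */
    static final class ToyVariable {
        final double value;
        double grad;

        ToyVariable(double value) {
            this.value = value;
        }
    }

    /** Toy collector that records the single operation y = x * x. */
    static final class ToyCollector implements AutoCloseable {
        private final ToyVariable x;

        ToyCollector(ToyVariable x) {
            this.x = x;
            x.grad = 0.0; // assumption: this toy resets gradients when a collector opens
        }

        /** Adds dy/dx = 2x to the existing gradient instead of overwriting it. */
        void backward() {
            x.grad += 2.0 * x.value;
        }

        @Override
        public void close() {
            // Nothing to release in this toy; the gradient stays readable after close.
        }
    }

    /** Runs two batches and returns the gradient observed after each one. */
    static double[] run() {
        ToyVariable x = new ToyVariable(3.0);

        try (ToyCollector batch1 = new ToyCollector(x)) {
            batch1.backward();
            batch1.backward(); // collector left open: 6.0 + 6.0 accumulate
        }
        double afterBatch1 = x.grad;

        try (ToyCollector batch2 = new ToyCollector(x)) { // new batch starts from zero
            batch2.backward();
        }
        double afterBatch2 = x.grad;

        return new double[] {afterBatch1, afterBatch2};
    }

    public static void main(String[] args) {
        double[] g = run();
        System.out.println(g[0] + " " + g[1]); // prints "12.0 6.0"
    }
}
```

The toy keeps gradients readable after close() so that an update step could still use them once the batch's collector has closed.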
Due to limitations in most engines, gradient collectors are global, so only one can be used at a time. Opening a second gradient collector while another is still open throws an error.
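The single-collector restriction can be modeled with a global flag. This is again a hypothetical toy rather than engine code: the class names and the choice of IllegalStateException are assumptions, since the documentation only states that an error is thrown.

```java
// Hypothetical model of a global, one-at-a-time collector; not DJL code.
final class ToySingleCollector {

    /** Global flag standing in for the engine's single recording state. */
    private static boolean open = false;

    static final class Collector implements AutoCloseable {
        Collector() {
            if (open) {
                // The exception type is an assumption; the docs only say "an error".
                throw new IllegalStateException("A gradient collector is already open");
            }
            open = true;
        }

        @Override
        public void close() {
            open = false; // frees the global slot for the next collector
        }
    }

    /** Returns true if opening a second collector while one is open is rejected. */
    static boolean nestedOpenRejected() {
        try (Collector outer = new Collector()) {
            try (Collector inner = new Collector()) {
                return false; // not reached: the inner constructor throws
            }
        } catch (IllegalStateException e) {
            return true; // outer was closed automatically before this catch ran
        }
    }

    /** Returns true if collectors opened one after another both succeed. */
    static boolean sequentialOpenAllowed() {
        try (Collector first = new Collector()) {
            // first batch would run here
        }
        try (Collector second = new Collector()) {
            // second batch would run here
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(nestedOpenRejected() + " " + sequentialOpenAllowed()); // prints "true true"
    }
}
```

Because try-with-resources always calls close(), the global slot is released even when the nested open fails, which is why sequential collectors still work afterwards.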
Method Summary
void backward(NDArray target)
Calculates the gradient of the target (head) with respect to the previously marked variables.
void close()
void zeroGradients()
Sets all the gradients within the engine to zero.
Method Detail
backward
void backward(NDArray target)
Calculates the gradient of the target (head) with respect to the previously marked variables.
Parameters:
target - the target NDArray (head) to differentiate with respect to the marked variables
zeroGradients
void zeroGradients()
Sets all the gradients within the engine to zero.
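The effect of zeroGradients() inside a single open collector can be shown with another toy model. The setup is illustrative only: two hypothetical scalar parameters a and b, a loss of a * b (so dL/da = b and dL/db = a), and a collector whose zeroGradients() sets every tracked gradient to zero, so a following backward call starts fresh instead of accumulating.

```java
import java.util.List;

// Hypothetical model of zeroGradients() inside one open collector; not DJL code.
final class ToyZeroGradients {

    /** A marked scalar parameter that holds its accumulated gradient. */
    static final class Param {
        final double value;
        double grad;

        Param(double value) {
            this.value = value;
        }
    }

    /** Toy collector for loss = a * b, so dL/da = b and dL/db = a. */
    static final class Collector implements AutoCloseable {
        private final Param a;
        private final Param b;

        Collector(Param a, Param b) {
            this.a = a;
            this.b = b;
        }

        /** Accumulates the gradient of the loss into both parameters. */
        void backward() {
            a.grad += b.value;
            b.grad += a.value;
        }

        /** Sets every tracked gradient to zero, mirroring the documented contract. */
        void zeroGradients() {
            for (Param p : List.of(a, b)) {
                p.grad = 0.0;
            }
        }

        @Override
        public void close() {
            // Nothing to release in this toy.
        }
    }

    /** Returns {a.grad, b.grad} after accumulating, zeroing, and one fresh backward call. */
    static double[] run() {
        Param a = new Param(2.0);
        Param b = new Param(5.0);
        try (Collector c = new Collector(a, b)) {
            c.backward();
            c.backward();      // accumulated: a.grad = 10.0, b.grad = 4.0
            c.zeroGradients(); // both gradients back to 0.0
            c.backward();      // fresh pass: a.grad = 5.0, b.grad = 2.0
        }
        return new double[] {a.grad, b.grad};
    }

    public static void main(String[] args) {
        double[] g = run();
        System.out.println(g[0] + " " + g[1]); // prints "5.0 2.0"
    }
}
```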
close
void close()
Specified by:
close in interface java.lang.AutoCloseable