Class PtGradientCollector

java.lang.Object
ai.djl.pytorch.engine.PtGradientCollector
All Implemented Interfaces:
ai.djl.training.GradientCollector, AutoCloseable

public final class PtGradientCollector extends Object implements ai.djl.training.GradientCollector
PtGradientCollector is the PyTorch implementation of GradientCollector.
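A minimal usage sketch, assuming the DJL PyTorch engine (the pytorch-engine artifact and its native library) is on the classpath and registered under the name "PyTorch". The class GradientExample and the helper gradientOfSumOfSquares are hypothetical names for illustration. The collector is obtained through Engine.newGradientCollector(), which for the PyTorch engine is expected to return a PtGradientCollector. The sketch computes the gradient of sum(x²), which should be 2x.

```java
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.GradientCollector;

import java.util.Arrays;

public class GradientExample {

    /** Hypothetical helper: returns d/dx of sum(x * x), which should equal 2x. */
    public static float[] gradientOfSumOfSquares(float[] values) {
        // Assumption: the DJL PyTorch engine is available under the name "PyTorch".
        Engine engine = Engine.getEngine("PyTorch");
        try (NDManager manager = engine.newBaseManager()) {
            NDArray x = manager.create(values);
            // Ask the engine to record gradients for x.
            x.setRequiresGradient(true);
            // For the PyTorch engine this collector is expected to be a PtGradientCollector.
            try (GradientCollector collector = engine.newGradientCollector()) {
                NDArray y = x.mul(x).sum(); // scalar y = sum(x^2)
                collector.backward(y);      // populates the gradient of x with 2x
            }
            // Copy the gradient into a Java array before the manager frees native memory.
            return x.getGradient().toFloatArray();
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(gradientOfSumOfSquares(new float[] {1f, 2f, 3f})));
    }
}
```

Obtaining the collector from the engine rather than calling the constructor keeps calling code engine-agnostic; only the engine name ties it to PyTorch.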
  • Constructor Details

    • PtGradientCollector

      public PtGradientCollector()
      Constructs a new PtGradientCollector instance.
  • Method Details

    • backward

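      public void backward(ai.djl.ndarray.NDArray target)
      Computes the gradients of target with respect to the arrays that were marked as requiring gradients (for example, with NDArray.setRequiresGradient(true)).
      Specified by:
      backward in interface ai.djl.training.GradientCollector
      Parameters:
      target - the NDArray (typically a scalar loss) whose gradients are computed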
    • zeroGradients

      public void zeroGradients()
      Sets all the gradients within the engine to zero.
      Specified by:
      zeroGradients in interface ai.djl.training.GradientCollector
    • close

      public void close()
      Specified by:
      close in interface AutoCloseable
      Specified by:
      close in interface ai.djl.training.GradientCollector
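The sketch below exercises the public constructor and all three methods in one step of plain gradient descent on the loss sum((w - target)²). It is a sketch under stated assumptions, not a recommended training loop: it assumes the DJL PyTorch engine is available under the name "PyTorch", and the class SgdStepExample and the method sgdStep are hypothetical. The arrays are created by a PyTorch NDManager because PtGradientCollector operates on PyTorch arrays; try-with-resources calls close() once the backward pass is done.

```java
import ai.djl.engine.Engine;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.pytorch.engine.PtGradientCollector;

import java.util.Arrays;

public class SgdStepExample {

    /** Hypothetical helper: one descent step on sum((w - target)^2); returns updated weights. */
    public static float[] sgdStep(float[] weights, float target, float learningRate) {
        // Assumption: the DJL PyTorch engine is on the classpath; arrays must be PyTorch arrays.
        try (NDManager manager = Engine.getEngine("PyTorch").newBaseManager()) {
            NDArray w = manager.create(weights);
            w.setRequiresGradient(true);
            NDArray grad;
            // Construct the collector directly; close() runs when the try block exits.
            try (PtGradientCollector collector = new PtGradientCollector()) {
                collector.zeroGradients(); // clear any gradients left over from earlier work
                NDArray loss = w.sub(target).square().sum();
                collector.backward(loss);  // d(loss)/dw = 2 * (w - target)
                grad = w.getGradient();
            }
            // Build the updated weights without modifying w in place, then copy them out.
            return w.sub(grad.mul(learningRate)).toFloatArray();
        }
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sgdStep(new float[] {0f, 5f}, 3f, 0.1f)));
    }
}
```

With weights [0, 5], target 3 and learning rate 0.1, the gradient is [-6, 4], so each weight moves toward 3 by 10% of its gradient.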