Class BasicGradientsAccumulator
- java.lang.Object
  - org.deeplearning4j.optimize.solvers.accumulation.BasicGradientsAccumulator
- All Implemented Interfaces:
  Serializable, GradientsAccumulator
public class BasicGradientsAccumulator extends Object implements GradientsAccumulator
- See Also:
  Serialized Form
Field Summary
Fields (Modifier and Type, Field):
- protected CyclicBarrier barrier
- protected List<INDArray> candidates
- protected AtomicLong extCounter
- protected AtomicLong firstOne
- protected IndexedTail gradients
- protected MessageHandler handler
- protected AtomicBoolean hasSomething
- protected char ordering
- protected AtomicLong ownCounter
- protected int parties
- protected long[] shape
- protected INDArray storage
- protected INDArray updates
- protected ReentrantReadWriteLock updatesLock
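The field list suggests (but does not document) a shared `updates` buffer guarded by `updatesLock`, with `hasSomething` flagging that data is present. The sketch below is a hypothetical, simplified illustration of that pattern only, not the DL4J implementation: `float[]` stands in for `INDArray`, and the class name `SharedUpdateBuffer` is invented for the example.

```java
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Hypothetical stand-in; float[] replaces INDArray.
class SharedUpdateBuffer {
    private final float[] updates;                       // accumulated update (cf. the `updates` field)
    private final ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();
    private final AtomicBoolean hasSomething = new AtomicBoolean(false);

    SharedUpdateBuffer(int length) {
        this.updates = new float[length];
    }

    // Writers add into the buffer under the exclusive write lock, then raise the flag.
    void add(float[] delta) {
        updatesLock.writeLock().lock();
        try {
            for (int i = 0; i < updates.length; i++) {
                updates[i] += delta[i];
            }
            hasSomething.set(true);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    // Readers copy the buffer under the shared read lock, so several may read at once.
    float[] snapshot() {
        updatesLock.readLock().lock();
        try {
            return updates.clone();
        } finally {
            updatesLock.readLock().unlock();
        }
    }

    boolean hasSomething() {
        return hasSomething.get();
    }
}
```

A read/write lock fits this shape because applying an update only needs a consistent read, while accumulation needs exclusive access.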
Constructor Summary
- BasicGradientsAccumulator(int parties)
  Creates a new GradientsAccumulator with a starting threshold of 1e-3.
- BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler)
  Creates a new GradientsAccumulator with a custom starting threshold.
Method Summary
All methods are concrete instance methods.
- void applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep)
  This method applies accumulated updates via the given StepFunction.
- void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha)
  This method applies accumulated updates via the given StepFunction.
- IndexedTail getExternalSource()
- boolean hasAnything()
  This method checks whether any (probably external) updates are available.
- void markExternalUpdates(boolean updatesAvailable)
  This method allows callers to signal early availability of updates.
- void receiveUpdate(INDArray array)
  This method accepts updates suitable for StepFunction and puts them into the queue used in the backpropagation loop. PLEASE NOTE: the array is expected to be ready for use and to match the params' dimensionality.
- void reset()
  This method resets all accumulated updates (if any).
- void setExternalSource(IndexedTail source)
  This method allows passing external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance.
- void storeUpdate(INDArray array, int iterationNumber, int epochNumber)
  This method accepts updates suitable for StepFunction and accumulates/propagates them across all workers.
- void touch()
  This method initializes the given worker with respect to Thread-Device Affinity.
Field Detail
- handler
  protected MessageHandler handler
- gradients
  protected transient IndexedTail gradients
- storage
  protected transient INDArray storage
- updates
  protected transient INDArray updates
- ownCounter
  protected transient AtomicLong ownCounter
- extCounter
  protected transient AtomicLong extCounter
- shape
  protected long[] shape
- ordering
  protected char ordering
- parties
  protected int parties
- barrier
  protected CyclicBarrier barrier
- firstOne
  protected AtomicLong firstOne
- updatesLock
  protected ReentrantReadWriteLock updatesLock
- hasSomething
  protected AtomicBoolean hasSomething
Constructor Detail
- BasicGradientsAccumulator
  public BasicGradientsAccumulator(int parties)
  Creates a new GradientsAccumulator with a starting threshold of 1e-3.
- BasicGradientsAccumulator
  public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler)
  Creates a new GradientsAccumulator with a custom starting threshold.
  - Parameters:
    handler - MessageHandler instance that will be used for communication purposes
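Both constructors take `parties`, and the class holds a `CyclicBarrier barrier` field, which suggests that `parties` workers synchronize with one another. The sketch below only illustrates what a `CyclicBarrier` sized by a `parties` count guarantees; it is a hypothetical example (the class `BarrierSketch` is invented) and does not claim to reproduce how `BasicGradientsAccumulator` uses its barrier.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical sketch: `parties` worker threads meet at a CyclicBarrier
// before reading the shared total, so every worker observes the same sum.
class BarrierSketch {
    static long[] run(int parties) {
        CyclicBarrier barrier = new CyclicBarrier(parties);
        AtomicLong total = new AtomicLong();
        long[] seen = new long[parties];
        Thread[] workers = new Thread[parties];
        for (int w = 0; w < parties; w++) {
            final int id = w;
            workers[w] = new Thread(() -> {
                total.addAndGet(id + 1);            // each worker contributes its "update"
                try {
                    barrier.await();                 // block until all parties have contributed
                } catch (InterruptedException | BrokenBarrierException e) {
                    throw new RuntimeException(e);
                }
                seen[id] = total.get();              // every contribution happened before the barrier tripped
            });
            workers[w].start();
        }
        for (Thread t : workers) {
            try {
                t.join();                            // join also makes seen[] visible to the caller
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return seen;
    }
}
```

With four parties, every worker reads 1 + 2 + 3 + 4 = 10 after the barrier, because no thread can pass `await()` until all four have arrived.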
Method Detail
- getExternalSource
  public IndexedTail getExternalSource()
  - Specified by:
    getExternalSource in interface GradientsAccumulator
- applyUpdate
  public void applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep)
  This method applies accumulated updates via the given StepFunction.
  - Specified by:
    applyUpdate in interface GradientsAccumulator
  - Parameters:
    function - the StepFunction used to apply the updates
    params - the parameters to be updated
- markExternalUpdates
  public void markExternalUpdates(boolean updatesAvailable)
  Description copied from interface: GradientsAccumulator
  This method allows callers to signal early availability of updates.
  - Specified by:
    markExternalUpdates in interface GradientsAccumulator
- applyUpdate
  public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha)
  This method applies accumulated updates via the given StepFunction.
  - Specified by:
    applyUpdate in interface GradientsAccumulator
  - Parameters:
    function - the StepFunction used to apply the updates
    params - the parameters to be updated
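How the update is applied depends entirely on the StepFunction passed in. As a minimal sketch, assuming a hypothetical step of the form params = params - alpha * update (a negative-direction step), applying an accumulated update might look like the example below. `SimpleStep` is an invented stand-in, not the DL4J `StepFunction` interface, and clearing the accumulator afterwards is an assumption of this sketch.

```java
import java.util.Arrays;

// Hypothetical stand-in for StepFunction; NOT the DL4J interface.
interface SimpleStep {
    void step(float[] params, float[] direction, double alpha);
}

class ApplySketch {
    // Assumed rule: params <- params - alpha * direction.
    static final SimpleStep NEGATIVE_STEP = (params, direction, alpha) -> {
        for (int i = 0; i < params.length; i++) {
            params[i] -= (float) (alpha * direction[i]);
        }
    };

    // Applies the accumulated update to params in place, then clears the
    // accumulator so the same update is not applied twice (an assumption).
    static void applyAccumulated(SimpleStep fn, float[] params, float[] accumulated, double alpha) {
        fn.step(params, accumulated, alpha);
        Arrays.fill(accumulated, 0f);
    }
}
```

For example, with params `{1, 1}`, accumulated `{2, 4}`, and alpha `0.5`, the params become `{0, -1}` and the accumulator is zeroed.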
- storeUpdate
  public void storeUpdate(INDArray array, int iterationNumber, int epochNumber)
  This method accepts updates suitable for StepFunction and accumulates/propagates them across all workers.
  - Specified by:
    storeUpdate in interface GradientsAccumulator
  - Parameters:
    array - the update to be accumulated
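"Propagates across all workers" can be pictured as a fan-out: an update stored by one worker is delivered to every worker's queue. The sketch below is a hypothetical illustration of that idea using per-worker `ConcurrentLinkedQueue`s; the class `FanOutSketch` is invented, `float[]` stands in for `INDArray`, and `iterationNumber`/`epochNumber` are accepted but unused here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;

// Hypothetical fan-out: an update stored by one worker is copied into every
// worker's queue, so all workers eventually apply the same set of updates.
class FanOutSketch {
    private final List<ConcurrentLinkedQueue<float[]>> queues = new ArrayList<>();

    FanOutSketch(int parties) {
        for (int i = 0; i < parties; i++) {
            queues.add(new ConcurrentLinkedQueue<>());
        }
    }

    // Mirrors the storeUpdate signature; iteration/epoch numbers are ignored in this sketch.
    void storeUpdate(float[] update, int iterationNumber, int epochNumber) {
        for (ConcurrentLinkedQueue<float[]> q : queues) {
            q.add(update.clone());   // defensive copy, so workers cannot mutate each other's data
        }
    }

    ConcurrentLinkedQueue<float[]> queueOf(int worker) {
        return queues.get(worker);
    }
}
```

The per-worker copy trades memory for isolation; a real implementation would more likely share one buffer and track each worker's read position.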
- receiveUpdate
  public void receiveUpdate(INDArray array)
  This method accepts updates suitable for StepFunction and puts them into the queue used in the backpropagation loop. PLEASE NOTE: the array is expected to be ready for use and to match the params' dimensionality.
  - Specified by:
    receiveUpdate in interface GradientsAccumulator
  - Parameters:
    array - the update to be queued
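The precondition above (the array must match the params' dimensionality) can be enforced at enqueue time. The following is a hypothetical sketch, not DL4J code: `ReceiveSketch` is an invented class, `float[]` stands in for `INDArray`, and summing queued updates in `drain()` is one plausible way a training loop might consume them.

```java
import java.util.concurrent.ConcurrentLinkedQueue;

// Hypothetical receiver: validates incoming updates, queues them, and lets
// the training loop drain everything queued so far as a single summed update.
class ReceiveSketch {
    private final int paramLength;
    private final ConcurrentLinkedQueue<float[]> pending = new ConcurrentLinkedQueue<>();

    ReceiveSketch(int paramLength) {
        this.paramLength = paramLength;
    }

    // Mirrors the documented precondition: the array must match the parameters' size.
    void receiveUpdate(float[] array) {
        if (array.length != paramLength) {
            throw new IllegalArgumentException(
                    "expected length " + paramLength + " but got " + array.length);
        }
        pending.add(array);
    }

    // Removes every queued update and returns their element-wise sum.
    float[] drain() {
        float[] sum = new float[paramLength];
        float[] next;
        while ((next = pending.poll()) != null) {
            for (int i = 0; i < paramLength; i++) {
                sum[i] += next[i];
            }
        }
        return sum;
    }
}
```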
- reset
  public void reset()
  This method resets all accumulated updates (if any).
  - Specified by:
    reset in interface GradientsAccumulator
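A reset has to clear every place where updates can be held and lower any "updates available" flag, ideally atomically with respect to writers. The sketch below is a hypothetical illustration using the same kinds of objects that appear in the field list (a lock, an `AtomicBoolean` flag, a buffer); the class `ResetSketch` is invented and `float[]` stands in for `INDArray`.

```java
import java.util.Arrays;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Hypothetical accumulator state that reset() must clear.
class ResetSketch {
    final float[] updates = new float[4];
    final ConcurrentLinkedQueue<float[]> pending = new ConcurrentLinkedQueue<>();
    final AtomicBoolean hasSomething = new AtomicBoolean();
    final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    // Drops every accumulated and queued update and lowers the availability flag,
    // holding the write lock so no writer can interleave with the reset.
    void reset() {
        lock.writeLock().lock();
        try {
            Arrays.fill(updates, 0f);
            pending.clear();
            hasSomething.set(false);
        } finally {
            lock.writeLock().unlock();
        }
    }
}
```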
- touch
  public void touch()
  This method initializes the given worker with respect to Thread-Device Affinity.
  - Specified by:
    touch in interface GradientsAccumulator
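Per-worker initialization of this kind is commonly done lazily, once per thread. The sketch below shows that pattern with a `ThreadLocal`; it is purely hypothetical: the class `AffinitySketch`, the integer "device" index, and the round-robin assignment are all assumptions for illustration and are not how DL4J binds threads to devices.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical: each worker thread is lazily bound to a "device" index the
// first time it calls touch(); later calls on the same thread are no-ops.
class AffinitySketch {
    private final int deviceCount;
    private final AtomicInteger nextDevice = new AtomicInteger();
    private final ThreadLocal<Integer> device = new ThreadLocal<>();

    AffinitySketch(int deviceCount) {
        this.deviceCount = deviceCount;
    }

    void touch() {
        if (device.get() == null) {
            // Round-robin assignment across devices (an assumption of this sketch).
            device.set(nextDevice.getAndIncrement() % deviceCount);
        }
    }

    Integer deviceOfCurrentThread() {
        return device.get();
    }
}
```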
- setExternalSource
  public void setExternalSource(IndexedTail source)
  Description copied from interface: GradientsAccumulator
  This method allows passing external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance.
  - Specified by:
    setExternalSource in interface GradientsAccumulator
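For every worker to see every external update, the source has to let multiple consumers read the same stream independently. The sketch below illustrates that idea as an append-only log with a per-consumer read position. It is a hypothetical simplification (`TailSketch` is invented) and is not a description of how `IndexedTail` is actually implemented; for brevity it never trims entries that all consumers have already read.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical multi-consumer "tail": updates are appended once, and each
// consumer keeps its own read position, so every worker sees every update.
class TailSketch {
    private final List<float[]> log = new ArrayList<>();
    private final Map<Integer, Integer> positions = new HashMap<>();

    synchronized void put(float[] update) {
        log.add(update);
    }

    // Returns the updates this consumer has not seen yet and advances its position.
    synchronized List<float[]> drain(int consumerId) {
        int from = positions.getOrDefault(consumerId, 0);
        List<float[]> fresh = new ArrayList<>(log.subList(from, log.size()));
        positions.put(consumerId, log.size());
        return fresh;
    }
}
```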
- hasAnything
  public boolean hasAnything()
  Description copied from interface: GradientsAccumulator
  This method checks whether any (probably external) updates are available.
  - Specified by:
    hasAnything in interface GradientsAccumulator
  - Returns:
    true if any updates are available, false otherwise
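`markExternalUpdates` and `hasAnything` together read like an "announce early, check cheaply" pair. As a hypothetical sketch (the class `AvailabilitySketch` is invented, and combining the flag with a queue check is an assumption), a producer can raise a flag before its updates arrive, and a consumer treats either the flag or a non-empty queue as "something available".

```java
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical availability tracking: an early-announcement flag plus a queue.
class AvailabilitySketch {
    private final AtomicBoolean externalFlag = new AtomicBoolean(false);
    private final ConcurrentLinkedQueue<float[]> pending = new ConcurrentLinkedQueue<>();

    // Lets a producer announce that updates are on their way before they arrive.
    void markExternalUpdates(boolean updatesAvailable) {
        externalFlag.set(updatesAvailable);
    }

    void enqueue(float[] update) {
        pending.add(update);
    }

    // True if updates were announced or are already queued.
    boolean hasAnything() {
        return externalFlag.get() || !pending.isEmpty();
    }
}
```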