public class BasicGradientsAccumulator extends Object implements GradientsAccumulator

| Modifier and Type | Field and Description |
|---|---|
| protected CyclicBarrier | barrier |
| protected List<INDArray> | candidates |
| protected AtomicLong | extCounter |
| protected AtomicLong | firstOne |
| protected IndexedTail | gradients |
| protected MessageHandler | handler |
| protected AtomicBoolean | hasSomething |
| protected char | ordering |
| protected AtomicLong | ownCounter |
| protected int | parties |
| protected long[] | shape |
| protected INDArray | storage |
| protected INDArray | updates |
| protected ReentrantReadWriteLock | updatesLock |

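The field table names the pieces but not how they cooperate. A minimal sketch, assuming updatesLock guards a shared updates buffer and hasSomething records whether anything has been accumulated since the last reset: writers take the exclusive lock, readers share it, and the flag can be checked without locking. Here double[] stands in for INDArray, and the class LockedUpdates with its methods add, snapshot, reset and hasAnything is hypothetical; this is not the library's implementation.

```java
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantReadWriteLock;

// Hypothetical sketch: double[] stands in for INDArray, and the locking scheme is an
// assumption about how updatesLock and hasSomething might cooperate.
class LockedUpdates {
    private final double[] updates;
    private final ReentrantReadWriteLock updatesLock = new ReentrantReadWriteLock();
    private final AtomicBoolean hasSomething = new AtomicBoolean(false);

    LockedUpdates(int length) {
        updates = new double[length];
    }

    // Writers take the exclusive lock while adding into the shared buffer.
    void add(double[] array) {
        updatesLock.writeLock().lock();
        try {
            for (int i = 0; i < updates.length; i++) {
                updates[i] += array[i];
            }
            hasSomething.set(true);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    // Readers share the read lock, so several threads may copy the buffer at once.
    double[] snapshot() {
        updatesLock.readLock().lock();
        try {
            return Arrays.copyOf(updates, updates.length);
        } finally {
            updatesLock.readLock().unlock();
        }
    }

    // Zeroes the buffer and clears the flag under the exclusive lock.
    void reset() {
        updatesLock.writeLock().lock();
        try {
            Arrays.fill(updates, 0.0);
            hasSomething.set(false);
        } finally {
            updatesLock.writeLock().unlock();
        }
    }

    // Lock-free check of the flag.
    boolean hasAnything() {
        return hasSomething.get();
    }
}
```

Because hasAnything only reads an AtomicBoolean, a training loop could poll it cheaply without contending for updatesLock.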

| Constructor | Description |
|---|---|
| BasicGradientsAccumulator(int parties) | Creates a new GradientsAccumulator with a starting threshold of 1e-3 |
| BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler) | Creates a new GradientsAccumulator with a custom starting threshold |

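Neither constructor description says what parties means. Given the CyclicBarrier barrier field, one plausible reading is that parties is the number of workers that must reach a common synchronization point before accumulated updates are shared. The sketch below shows that pattern with plain JDK classes only: each of parties threads adds its contribution, waits on a CyclicBarrier sized by parties, and then reads the common average. BarrierAveraging and run are hypothetical names, and averaging is an assumption chosen for illustration.

```java
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

// Hypothetical sketch: `parties` worker threads each contribute one value, meet at a
// CyclicBarrier sized by `parties`, and then all read the same average.
final class BarrierAveraging {
    private BarrierAveraging() {}

    static double[] run(double[] contributions) {
        final int parties = contributions.length;
        final double[] sum = new double[1];          // shared accumulator, guarded by `sum`
        final double[] seen = new double[parties];   // value each worker observed after the barrier
        final CyclicBarrier barrier = new CyclicBarrier(parties);
        Thread[] workers = new Thread[parties];

        for (int w = 0; w < parties; w++) {
            final int id = w;
            workers[w] = new Thread(() -> {
                synchronized (sum) {
                    sum[0] += contributions[id];
                }
                try {
                    barrier.await();                 // block until every party has contributed
                } catch (InterruptedException | BrokenBarrierException e) {
                    throw new IllegalStateException(e);
                }
                synchronized (sum) {
                    seen[id] = sum[0] / parties;
                }
            });
            workers[w].start();
        }
        for (Thread t : workers) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException(e);
            }
        }
        return seen;
    }
}
```

For example, BarrierAveraging.run(new double[]{1.0, 2.0, 3.0}) returns an array in which every worker observed 2.0, because no worker can read the sum until all three have added to it.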

| Modifier and Type | Method | Description |
|---|---|---|
| void | applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep) | This method applies accumulated updates via the given StepFunction |
| void | applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha) | This method applies accumulated updates via the given StepFunction |
| IndexedTail | getExternalSource() | |
| boolean | hasAnything() | This method checks whether any (possibly external) updates are available |
| void | markExternalUpdates(boolean updatesAvailable) | This method flags the early availability of updates |
| void | receiveUpdate(INDArray array) | This method accepts updates suitable for StepFunction and puts them into the queue used in the backpropagation loop. PLEASE NOTE: array is expected to be ready for use and to match the dimensionality of params |
| void | reset() | This method resets all accumulated updates (if any) |
| void | setExternalSource(IndexedTail source) | This method allows passing external updates to the accumulator; they will be propagated to all workers using this GradientsAccumulator instance |
| void | storeUpdate(INDArray array, int iterationNumber, int epochNumber) | This method accepts updates suitable for StepFunction, and accumulates/propagates them across all workers |
| void | touch() | This method initializes the given worker with respect to thread-device affinity |

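The method summary says receiveUpdate puts updates into a queue used by the backpropagation loop, markExternalUpdates flags their early availability, hasAnything reports whether any are available, and applyUpdate applies them via a StepFunction. A minimal sketch of how those pieces could fit together, assuming a lock-free queue and a single boolean flag. HypotheticalStep, ExternalUpdatesSketch, applyQueued and the params -= alpha * update rule are all invented for illustration, double[] stands in for INDArray, and the grad argument of the real applyUpdate is left out because this page does not describe it.

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;

// Hypothetical stand-in for StepFunction: decides how one update is applied to params.
@FunctionalInterface
interface HypotheticalStep {
    void step(double[] params, double[] update, double alpha);
}

// Hypothetical sketch of the queue/flag contract; not the library's implementation.
class ExternalUpdatesSketch {
    // Plain gradient-descent-style rule: params -= alpha * update.
    static final HypotheticalStep SUBTRACT = (params, update, alpha) -> {
        for (int i = 0; i < params.length; i++) {
            params[i] -= alpha * update[i];
        }
    };

    private final Queue<double[]> queue = new ConcurrentLinkedQueue<>();
    private final AtomicBoolean externalFlag = new AtomicBoolean(false);

    // Like receiveUpdate: the caller must pass an array that already matches params.
    void receiveUpdate(double[] array) {
        queue.add(array);
    }

    // Like markExternalUpdates: announces updates before they arrive.
    void markExternalUpdates(boolean updatesAvailable) {
        externalFlag.set(updatesAvailable);
    }

    // Like hasAnything: true if updates were announced or are already queued.
    boolean hasAnything() {
        return externalFlag.get() || !queue.isEmpty();
    }

    // Loosely like applyUpdate(function, params, grad, alpha), minus grad: drains the
    // queue through the step function and returns how many updates were applied.
    int applyQueued(HypotheticalStep function, double[] params, double alpha) {
        int applied = 0;
        double[] update;
        while ((update = queue.poll()) != null) {
            function.step(params, update, alpha);
            applied++;
        }
        externalFlag.set(false);
        return applied;
    }
}
```

The flag and the queue are deliberately independent here: hasAnything can become true before any array arrives, which is one way to read "early availability".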
protected MessageHandler handler
protected transient IndexedTail gradients
protected transient INDArray storage
protected transient INDArray updates
protected transient AtomicLong ownCounter
protected transient AtomicLong extCounter
protected long[] shape
protected char ordering
protected int parties
protected CyclicBarrier barrier
protected AtomicLong firstOne
protected ReentrantReadWriteLock updatesLock
protected AtomicBoolean hasSomething
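Several of the fields above are declared transient (gradients, storage, updates, ownCounter, extCounter), while the remaining fields, such as handler, shape, ordering and parties, are not. This page does not say whether the accumulator is ever serialized, but under standard Java serialization transient fields are skipped: after deserialization they hold their default values (null or 0), and field initializers do not run. The hypothetical class TransientFields illustrates that rule in isolation.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Illustration of Java's `transient` keyword with hypothetical fields; not library code.
class TransientFields implements Serializable {
    private static final long serialVersionUID = 1L;

    int parties = 4;                          // written by serialization, like `protected int parties`
    transient double[] storage = {1.0, 2.0};  // skipped, like `protected transient INDArray storage`

    // Serializes and deserializes an instance entirely in memory.
    static TransientFields roundTrip(TransientFields original) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(original);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (TransientFields) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

After roundTrip, parties is still 4 but storage is null, so any transient buffer would have to be re-created before use.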
public BasicGradientsAccumulator(int parties)
public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler)
Parameters:
handler - MessageHandler instance that will be used for communication purposes

public IndexedTail getExternalSource()
Specified by: getExternalSource in interface GradientsAccumulator

public void applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep)
Specified by: applyUpdate in interface GradientsAccumulator
Parameters: function, params

public void markExternalUpdates(boolean updatesAvailable)
Description copied from interface: GradientsAccumulator
Specified by: markExternalUpdates in interface GradientsAccumulator

public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha)
Specified by: applyUpdate in interface GradientsAccumulator
Parameters: function, params

public void storeUpdate(INDArray array, int iterationNumber, int epochNumber)
Specified by: storeUpdate in interface GradientsAccumulator
Parameters: array

public void receiveUpdate(INDArray array)
Specified by: receiveUpdate in interface GradientsAccumulator
Parameters: array

public void reset()
Specified by: reset in interface GradientsAccumulator

public void touch()
Specified by: touch in interface GradientsAccumulator

public void setExternalSource(IndexedTail source)
Description copied from interface: GradientsAccumulator
Specified by: setExternalSource in interface GradientsAccumulator

public boolean hasAnything()
Description copied from interface: GradientsAccumulator
Specified by: hasAnything in interface GradientsAccumulator

Copyright © 2020. All rights reserved.