Class EncodedGradientsAccumulator
- java.lang.Object
-
- org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator
-
- All Implemented Interfaces:
Serializable, GradientsAccumulator, Registerable
public class EncodedGradientsAccumulator extends Object implements GradientsAccumulator, Registerable
- See Also:
- Serialized Form
-
-
Nested Class Summary
Nested Classes
static class EncodedGradientsAccumulator.Builder
-
Field Summary
-
Constructor Summary
Constructors
EncodedGradientsAccumulator(int parties, double threshold)
EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize, Integer boundary, boolean encodingDebugMode)
EncodedGradientsAccumulator(int parties, ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, boolean encodingDebugMode)
-
Method Summary
Methods
void applyUpdate(StepFunction function, INDArray params, INDArray updates, boolean isFinalStep)
  This method applies accumulated updates via the given StepFunction.
void applyUpdate(StepFunction function, INDArray params, INDArray updates, double alpha)
  This method applies accumulated updates via the given StepFunction.
void fallbackToSingleConsumerMode(boolean reallyFallback)
  This method enables/disables bypass mode.
IndexedTail getExternalSource()
static long getOptimalBufferSize(long paramsLength, int numWorkers, int queueSize)
  This method returns the optimal bufferSize for a given model. Updates are guaranteed to have a maximum size of params / 16.
static long getOptimalBufferSize(Model model, int numWorkers, int queueSize)
boolean hasAnything()
  This method checks if there are any (probably external) updates available.
void markExternalUpdates(boolean updatesAvailable)
  This method allows highlighting early availability of updates.
void receiveUpdate(INDArray array)
  This method accepts updates suitable for StepFunction and puts them into the queue, which is used in the backpropagation loop.
void registerConsumers(int numConsumers)
  This method notifies the producer about the number of consumers for the current consumption cycle.
void reset()
  This method resets all accumulated updates (if any).
void setExternalSource(IndexedTail source)
  This method allows passing external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance.
void storeUpdate(INDArray array, int iterationNumber, int epochNumber)
  This method accepts updates suitable for StepFunction, and accumulates/propagates them across all workers.
protected void synchronize(int consumers)
protected void synchronize(int consumers, boolean finalLock)
void touch()
  This method initializes the given worker with respect to Thread-Device Affinity.
-
-
-
Field Detail
-
DEFAULT_INITIAL_MEMORY
public static final long DEFAULT_INITIAL_MEMORY
- See Also:
- Constant Field Values
-
accumulator
protected ThreadLocal<INDArray> accumulator
-
parties
protected int parties
-
handler
protected MessageHandler handler
-
messages
protected List<BlockingQueue<INDArray>> messages
-
workspaces
protected List<MemoryWorkspace> workspaces
-
locks
protected List<ReentrantLock> locks
-
workersCounter
protected AtomicInteger workersCounter
-
index
protected ThreadLocal<Integer> index
-
initialMemory
protected long initialMemory
-
queueSize
protected int queueSize
-
boundary
protected Integer boundary
-
encodingDebugMode
protected boolean encodingDebugMode
-
externalSource
protected IndexedTail externalSource
-
isFirst
protected AtomicBoolean isFirst
-
isDone
protected AtomicBoolean isDone
-
barrier
protected AtomicInteger barrier
-
secondary
protected AtomicInteger secondary
-
registered
protected AtomicBoolean registered
-
bypassMode
protected AtomicBoolean bypassMode
-
currentConsumers
protected final AtomicInteger currentConsumers
-
throwable
protected final AtomicThrowable throwable
-
isDebug
protected boolean isDebug
-
relocatable
protected final boolean relocatable
-
updatesApplied
protected ThreadLocal<AtomicLong> updatesApplied
-
externalUpdatesAvailable
protected AtomicBoolean externalUpdatesAvailable
-
appliedConfiguration
protected WorkspaceConfiguration appliedConfiguration
-
-
Constructor Detail
-
EncodedGradientsAccumulator
public EncodedGradientsAccumulator(int parties, double threshold)
-
EncodedGradientsAccumulator
public EncodedGradientsAccumulator(int parties, ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, boolean encodingDebugMode)
-
EncodedGradientsAccumulator
public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize, Integer boundary, boolean encodingDebugMode)
-
-
Method Detail
-
getOptimalBufferSize
public static long getOptimalBufferSize(long paramsLength, int numWorkers, int queueSize)
This method returns the optimal bufferSize for a given model. Updates are guaranteed to have a maximum size of params / 16 int elements. For example, a model with 100m params takes 400m bytes as floats (or 800m bytes as doubles). The worst case is bitmap encoding, which takes 2 bits to encode each gradient value, so one 32-bit int holds 16 values and the encoded update has at most (100m / 16) = 6.25m int elements. The buffer size is therefore 6.25m * queueSize * 4 bytes per int.
- Parameters:
paramsLength
numWorkers
queueSize
- Returns:
-
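The sizing arithmetic described for getOptimalBufferSize(long, int, int) can be sketched in plain Java. This is a minimal sketch of the worst-case estimate as stated in the description, not the library's implementation: the class BufferSizeSketch and its methods are hypothetical names, numWorkers is left out because the description does not say how it enters the result, and the division is rounded up (an assumption) so that a partially filled int is still counted.

```java
// Hypothetical helper, not part of Deeplearning4j: it only reproduces the
// arithmetic from the method description above.
final class BufferSizeSketch {
    // Bitmap encoding uses 2 bits per gradient value, so a 32-bit int holds 16 values.
    static final int VALUES_PER_INT = 16;
    static final int BYTES_PER_INT = 4;

    // Worst-case number of int elements in one encoded update.
    // Rounding up is an assumption of this sketch.
    static long worstCaseEncodedInts(long paramsLength) {
        return (paramsLength + VALUES_PER_INT - 1) / VALUES_PER_INT;
    }

    // Bytes needed to hold queueSize worst-case updates.
    static long estimateBufferBytes(long paramsLength, int queueSize) {
        return worstCaseEncodedInts(paramsLength) * queueSize * BYTES_PER_INT;
    }

    public static void main(String[] args) {
        long paramsLength = 100_000_000L; // the 100m-parameter model from the description
        int queueSize = 5;                // hypothetical queue size
        System.out.println("ints per update: " + worstCaseEncodedInts(paramsLength));
        System.out.println("buffer bytes:    " + estimateBufferBytes(paramsLength, queueSize));
    }
}
```

For 100m parameters and the hypothetical queueSize of 5, the sketch gives 6.25m int elements per update and 125m bytes in total.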
getOptimalBufferSize
public static long getOptimalBufferSize(Model model, int numWorkers, int queueSize)
-
fallbackToSingleConsumerMode
public void fallbackToSingleConsumerMode(boolean reallyFallback)
Description copied from interface: Registerable
This method enables/disables bypass mode.
- Specified by:
fallbackToSingleConsumerMode in interface Registerable
-
registerConsumers
public void registerConsumers(int numConsumers)
Description copied from interface: Registerable
This method notifies the producer about the number of consumers for the current consumption cycle.
- Specified by:
registerConsumers in interface Registerable
-
getExternalSource
public IndexedTail getExternalSource()
- Specified by:
getExternalSource in interface GradientsAccumulator
-
markExternalUpdates
public void markExternalUpdates(boolean updatesAvailable)
Description copied from interface: GradientsAccumulator
This method allows highlighting early availability of updates.
- Specified by:
markExternalUpdates in interface GradientsAccumulator
-
synchronize
protected void synchronize(int consumers)
-
synchronize
protected void synchronize(int consumers, boolean finalLock)
-
applyUpdate
public void applyUpdate(StepFunction function, INDArray params, INDArray updates, boolean isFinalStep)
This method applies accumulated updates via the given StepFunction.
- Specified by:
applyUpdate in interface GradientsAccumulator
- Parameters:
function
params
-
applyUpdate
public void applyUpdate(StepFunction function, INDArray params, INDArray updates, double alpha)
This method applies accumulated updates via the given StepFunction.
- Specified by:
applyUpdate in interface GradientsAccumulator
- Parameters:
function
params
alpha
-
setExternalSource
public void setExternalSource(IndexedTail source)
This method allows passing external updates to the accumulator; they will be propagated across all workers using this GradientsAccumulator instance.
- Specified by:
setExternalSource in interface GradientsAccumulator
- Parameters:
source
-
touch
public void touch()
This method initializes the given worker with respect to Thread-Device Affinity.
- Specified by:
touch in interface GradientsAccumulator
-
storeUpdate
public void storeUpdate(INDArray array, int iterationNumber, int epochNumber)
This method accepts updates suitable for StepFunction, and accumulates/propagates them across all workers.
- Specified by:
storeUpdate in interface GradientsAccumulator
- Parameters:
array
-
receiveUpdate
public void receiveUpdate(INDArray array)
This method accepts updates suitable for StepFunction and puts them into the queue, which is used in the backpropagation loop.
PLEASE NOTE: array is expected to be ready for use and match params dimensionality.
- Specified by:
receiveUpdate in interface GradientsAccumulator
- Parameters:
array
-
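The queue-based hand-off described for receiveUpdate can be illustrated with a small stand-alone sketch. It is not the library's code: FanOutSketch and its methods are hypothetical, plain double[] arrays stand in for INDArray, and element-wise addition stands in for a StepFunction. The only detail taken from this page is the shape of the messages field, a list holding one BlockingQueue per worker.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical illustration of fanning one update out to per-worker queues.
final class FanOutSketch {
    // One bounded queue per worker, mirroring the shape of the `messages` field.
    private final List<BlockingQueue<double[]>> messages = new ArrayList<>();

    FanOutSketch(int parties, int queueSize) {
        for (int i = 0; i < parties; i++) {
            messages.add(new LinkedBlockingQueue<>(queueSize));
        }
    }

    // Copies the update into every worker's queue; blocks while a queue is full.
    void receiveUpdate(double[] update) {
        for (BlockingQueue<double[]> queue : messages) {
            try {
                queue.put(update.clone());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted while queueing update", e);
            }
        }
    }

    // Drains one worker's queue and adds each pending update to params in place.
    // Returns the number of updates applied. Addition is an assumption of this sketch.
    int applyPending(int worker, double[] params) {
        int applied = 0;
        double[] update;
        while ((update = messages.get(worker).poll()) != null) {
            for (int i = 0; i < params.length; i++) {
                params[i] += update[i];
            }
            applied++;
        }
        return applied;
    }

    // True if the given worker still has queued updates.
    boolean hasAnything(int worker) {
        return !messages.get(worker).isEmpty();
    }

    public static void main(String[] args) {
        FanOutSketch sketch = new FanOutSketch(2, 4);
        sketch.receiveUpdate(new double[] {1.0, 2.0});
        double[] params = {0.5, 0.5};
        int applied = sketch.applyPending(0, params);
        System.out.println(applied + " update(s) applied, params[0] = " + params[0]);
    }
}
```

Each worker drains only its own queue, so the update in the sketch must already match the length of params, in line with the note above about params dimensionality.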
reset
public void reset()
This method resets all accumulated updates (if any).
- Specified by:
reset in interface GradientsAccumulator
-
hasAnything
public boolean hasAnything()
Description copied from interface: GradientsAccumulator
This method checks if there are any (probably external) updates available.
- Specified by:
hasAnything in interface GradientsAccumulator
- Returns:
-
-