Package ai.djl.training
Class ParameterStore
java.lang.Object
    ai.djl.training.ParameterStore
public class ParameterStore extends java.lang.Object
The ParameterStore contains a map from each parameter to its mirrors on other devices.
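To make "a map from a parameter to its mirrors" concrete, here is a minimal, library-free Java sketch of that layout. It does not use DJL: parameters are identified by hypothetical string names, devices by hypothetical strings such as "gpu0", and a plain double[] stands in for NDArray. The nested-map shape is an illustration under those assumptions, not DJL's internal representation.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical, library-free model of "a map from a parameter to its mirrors on devices".
// Strings stand in for Parameter and Device; double[] stands in for NDArray.
final class MirrorMap {
    // parameter name -> (device name -> that device's copy of the values)
    private final Map<String, Map<String, double[]>> mirrors = new LinkedHashMap<>();

    // Record one independent copy of the parameter's values on each listed device.
    void mirror(String parameter, double[] values, String... devices) {
        Map<String, double[]> perDevice =
                mirrors.computeIfAbsent(parameter, k -> new LinkedHashMap<>());
        for (String device : devices) {
            perDevice.put(device, Arrays.copyOf(values, values.length));
        }
    }

    // Look up the mirror of a parameter on one device (null if none exists yet).
    double[] get(String parameter, String device) {
        Map<String, double[]> perDevice = mirrors.get(parameter);
        return perDevice == null ? null : perDevice.get(device);
    }

    public static void main(String[] args) {
        MirrorMap store = new MirrorMap();
        store.mirror("dense.weight", new double[] {0.1, 0.2}, "gpu0", "gpu1");
        store.get("dense.weight", "gpu0")[0] = 9.0; // editing one mirror...
        System.out.println(store.get("dense.weight", "gpu1")[0]); // ...leaves the other at 0.1
    }
}
```

Because each device holds its own copy, a write to one mirror is invisible on the others until something explicitly re-synchronizes them.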
Constructor Summary
Constructors
ParameterStore()
    Constructs a new ParameterStore instance.
ParameterStore(NDManager manager, boolean copy)
    Constructs an empty ParameterStore.
Method Summary
NDManager getManager()
    Returns the NDManager associated with this ParameterStore.
NDArray getValue(Parameter parameter, Device device, boolean training)
    Returns the value of a mirrored parameter on a device.
void setParameterServer(ParameterServer parameterServer, Device[] devices)
    Sets the ParameterServer used to apply updates to the parameters.
void sync()
    Synchronizes the values on all mirrors with the main parameter.
void updateAllParameters()
    Updates all the mirrored parameters.
Constructor Detail
ParameterStore
public ParameterStore()
Constructs a new ParameterStore instance.
ParameterStore
public ParameterStore(NDManager manager, boolean copy)
Constructs an empty ParameterStore.
Parameters:
manager - the manager to attach mirrored parameters to
copy - whether to always copy, even for the same device as the original parameter
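The copy parameter can be read as a rule for when a mirror may share storage with the original. The following hypothetical sketch (plain Java, not DJL; the class and method names are illustrative only, and double[] stands in for NDArray) encodes that reading: with copy set to false, a request for the parameter on its own device returns the original array, and every other request returns a fresh copy.

```java
import java.util.Arrays;

// Hypothetical sketch of the `copy` flag: it decides whether a mirror on the
// parameter's own device may share the original storage.
final class CopyPolicy {
    private final boolean copy;

    CopyPolicy(boolean copy) {
        this.copy = copy;
    }

    double[] mirrorFor(double[] original, String originalDevice, String targetDevice) {
        if (!copy && originalDevice.equals(targetDevice)) {
            return original; // same device and copying not forced: reuse the original
        }
        return Arrays.copyOf(original, original.length); // otherwise always a fresh copy
    }

    public static void main(String[] args) {
        double[] w = {1.0, 2.0};
        System.out.println(new CopyPolicy(false).mirrorFor(w, "gpu0", "gpu0") == w); // true
        System.out.println(new CopyPolicy(true).mirrorFor(w, "gpu0", "gpu0") == w);  // false
        System.out.println(new CopyPolicy(false).mirrorFor(w, "gpu0", "gpu1") == w); // false
    }
}
```

Sharing avoids a redundant allocation on the original device; forcing a copy keeps the mirror fully independent of the original parameter.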
Method Detail
setParameterServer
public void setParameterServer(ParameterServer parameterServer, Device[] devices)
Sets the ParameterServer used to apply updates to the parameters.
Parameters:
parameterServer - the ParameterServer that applies updates to the parameters
devices - the devices to create mirrored parameters on
updateAllParameters
public void updateAllParameters()
Updates all the mirrored parameters.
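setParameterServer and updateAllParameters together describe a step in which a parameter server turns per-device gradients into new parameter values. The sketch below is a hypothetical, self-contained model of one such step (plain Java, not DJL). It assumes the gradients from all devices are summed and applied with plain SGD (new = old - learningRate * sum), and that the same step is written to every mirror; neither the summation nor SGD is stated on this page, and DJL's actual ParameterServer and optimizer may behave differently.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a parameter server: sums per-device gradients and
// applies a plain SGD step (aggregation rule and optimizer are assumptions).
final class ToySgdServer {
    private final double learningRate;

    ToySgdServer(double learningRate) {
        this.learningRate = learningRate;
    }

    // Sum the gradients from every device, then apply the same step to each mirror,
    // so mirrors that start with equal values stay equal.
    void update(List<double[]> mirrors, List<double[]> gradients) {
        int n = mirrors.get(0).length;
        double[] total = new double[n];
        for (double[] g : gradients) {
            for (int i = 0; i < n; i++) {
                total[i] += g[i];
            }
        }
        for (double[] m : mirrors) {
            for (int i = 0; i < n; i++) {
                m[i] -= learningRate * total[i];
            }
        }
    }

    public static void main(String[] args) {
        List<double[]> mirrors = new ArrayList<>();
        mirrors.add(new double[] {1.0, 1.0}); // mirror on device 0
        mirrors.add(new double[] {1.0, 1.0}); // mirror on device 1
        List<double[]> grads = new ArrayList<>();
        grads.add(new double[] {0.5, 1.0}); // gradient computed on device 0
        grads.add(new double[] {0.5, 1.0}); // gradient computed on device 1
        new ToySgdServer(0.25).update(mirrors, grads);
        System.out.println(mirrors.get(0)[0] + " " + mirrors.get(1)[1]); // 0.75 0.5
    }
}
```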
getValue
public NDArray getValue(Parameter parameter, Device device, boolean training)
Returns the value of a mirrored parameter on a device.
Parameters:
parameter - the parameter to get the value for
device - the device to get the mirror from
training - true for a training forward pass
Returns:
the value of the mirrored parameter on the device
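A getValue call can be modeled as a per-device lookup that creates the mirror on first use. In this hypothetical sketch (plain Java, not DJL), a missing mirror is copied from the main values the first time a device asks for it, and later calls return that same mirror. Lazy creation is an assumption made for illustration, and the training argument is left out because this page does not say how it changes the lookup.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of getValue: one mirror per (parameter, device), created on first access.
// The `training` argument of the real method is deliberately omitted from this sketch.
final class LazyMirrors {
    private final Map<String, double[]> mainValues = new HashMap<>();
    private final Map<String, Map<String, double[]>> mirrors = new HashMap<>();

    void register(String parameter, double[] values) {
        mainValues.put(parameter, values);
    }

    double[] getValue(String parameter, String device) {
        double[] main = mainValues.get(parameter);
        if (main == null) {
            throw new IllegalArgumentException("unknown parameter: " + parameter);
        }
        // Copy the main values into a new mirror only if this device has none yet.
        return mirrors
                .computeIfAbsent(parameter, k -> new HashMap<>())
                .computeIfAbsent(device, d -> Arrays.copyOf(main, main.length));
    }

    public static void main(String[] args) {
        LazyMirrors store = new LazyMirrors();
        store.register("conv.bias", new double[] {0.0, 1.0});
        double[] first = store.getValue("conv.bias", "gpu0");
        double[] second = store.getValue("conv.bias", "gpu0");
        System.out.println(first == second); // true: the mirror is created once and reused
    }
}
```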
getManager
public NDManager getManager()
Returns the NDManager associated with this ParameterStore.
Returns:
the NDManager
sync
public void sync()
Synchronizes the values on all mirrors with the main parameter.
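sync can be pictured as copying the main parameter's current values over every mirror. This hypothetical sketch (plain Java, not DJL, with double[] in place of NDArray) shows that step; copying element-wise into the existing mirror arrays is an illustrative choice, not a description of DJL's implementation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of sync(): overwrite every mirror with the main parameter's values.
final class SyncSketch {
    static void sync(double[] main, List<double[]> mirrors) {
        for (double[] mirror : mirrors) {
            // Copy in place so references already held by callers see the refreshed values.
            System.arraycopy(main, 0, mirror, 0, main.length);
        }
    }

    public static void main(String[] args) {
        double[] main = {0.5, -0.5};
        List<double[]> mirrors = new ArrayList<>();
        mirrors.add(new double[] {0.0, 0.0}); // stale mirror on one device
        mirrors.add(new double[] {1.0, 1.0}); // stale mirror on another device
        sync(main, mirrors);
        System.out.println(mirrors.get(0)[0] + " " + mirrors.get(1)[1]); // 0.5 -0.5
    }
}
```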