public class Sgd extends Object implements ComputeFunction<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage>
Modifier and Type | Class and Description |
---|---|
class | Sgd.InitItemsComputation: This computation class is used to initialize the factors of the item nodes in the second superstep. |
class | Sgd.InitUsersComputation: This computation class is used to initialize the factors of the user nodes in the very first superstep, and send the first updates to the item nodes. |
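Both nested computations initialize latent factor vectors before SGD starts. The following is a hypothetical sketch of such an initialization, not the code of Sgd.InitUsersComputation or Sgd.InitItemsComputation. It assumes the factors are drawn uniformly from [0, 1) with java.util.Random, sized by something like VECTOR_SIZE and seeded from something like RANDOM_SEED. The class name FactorInitSketch, the method initFactors, and the mixing of the seed with a vertex id are illustrative inventions.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch; not the actual Sgd.InitUsersComputation / Sgd.InitItemsComputation.
class FactorInitSketch {

    // Returns vectorSize floats drawn uniformly from [0, 1) (assumed distribution).
    static float[] initFactors(int vectorSize, long randomSeed, long vertexId) {
        // Assumption: mix the global seed with the vertex id so each vertex gets a
        // reproducible vector regardless of the order in which vertices are processed.
        Random rnd = new Random(randomSeed ^ vertexId);
        float[] factors = new float[vectorSize];
        for (int i = 0; i < vectorSize; i++) {
            factors[i] = rnd.nextFloat();
        }
        return factors;
    }

    public static void main(String[] args) {
        float[] a = initFactors(4, 42L, 7L);
        float[] b = initFactors(4, 42L, 7L);
        System.out.println(Arrays.toString(a));
        // Same seed and vertex id always yield the same vector.
        System.out.println(Arrays.equals(a, b)); // true
    }
}
```

Seeding per vertex keeps reruns reproducible, which is presumably why a RANDOM_SEED parameter exists at all.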
Nested classes/interfaces inherited from interface ComputeFunction: ComputeFunction.Aggregators, ComputeFunction.Callback<K,VV,EV,Message>, ComputeFunction.InitCallback, ComputeFunction.MasterCallback, ComputeFunction.ReadAggregators, ComputeFunction.ReadWriteAggregators
Modifier and Type | Field and Description |
---|---|
static String | GAMMA: Keyword for parameter setting the learning rate GAMMA. |
static float | GAMMA_DEFAULT: Default value for GAMMA. |
static String | ITERATIONS: Keyword for parameter setting the number of iterations. |
static int | ITERATIONS_DEFAULT: Default value for ITERATIONS. |
static String | LAMBDA: Keyword for parameter setting the regularization parameter LAMBDA. |
static float | LAMBDA_DEFAULT: Default value for LAMBDA. |
static String | MAX_RATING: Keyword for parameter setting the maximum rating. |
static float | MAX_RATING_DEFAULT: Default maximum rating. |
protected float | maxRating |
static String | MIN_RATING: Keyword for parameter setting the minimum rating. |
static float | MIN_RATING_DEFAULT: Default minimum rating. |
protected float | minRating |
static String | RANDOM_SEED: Keyword for parameter setting the random seed. |
static Long | RANDOM_SEED_DEFAULT: Default random seed. |
static String | RMSE_AGGREGATOR: Name of the aggregator used to compute the RMSE. |
static String | RMSE_TARGET: Keyword for the RMSE aggregator tolerance. |
static float | RMSE_TARGET_DEFAULT: Default value for RMSE_TARGET, the parameter enabling the RMSE aggregator. |
static String | TOLERANCE: Keyword for parameter setting the convergence tolerance. |
static float | TOLERANCE_DEFAULT: Default value for TOLERANCE. |
static String | VECTOR_SIZE: Keyword for parameter setting the latent vector size. |
static int | VECTOR_SIZE_DEFAULT: Default value for VECTOR_SIZE. |
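The keyword fields above are presumably looked up in the Map<String,?> configs passed to init(), with the matching *_DEFAULT field as the fallback. A minimal sketch of such a lookup follows. The key strings "gamma", "lambda" and "dim" and the default values are placeholders, not the real values of Sgd.GAMMA, Sgd.LAMBDA, Sgd.VECTOR_SIZE or their defaults, and SgdConfigSketch is a hypothetical name.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper; key strings and defaults are placeholders, not Sgd's real constants.
class SgdConfigSketch {

    // Accepts a Number or a String value; returns defaultValue when the key is absent.
    static float getFloat(Map<String, ?> configs, String key, float defaultValue) {
        Object v = configs.get(key);
        if (v == null) {
            return defaultValue;
        }
        if (v instanceof Number) {
            return ((Number) v).floatValue();
        }
        return Float.parseFloat(v.toString());
    }

    // Same lookup rules as getFloat, for integer parameters such as a vector size.
    static int getInt(Map<String, ?> configs, String key, int defaultValue) {
        Object v = configs.get(key);
        if (v == null) {
            return defaultValue;
        }
        if (v instanceof Number) {
            return ((Number) v).intValue();
        }
        return Integer.parseInt(v.toString());
    }

    public static void main(String[] args) {
        Map<String, Object> configs = new HashMap<>();
        configs.put("gamma", "0.005"); // hypothetical key for the learning rate
        configs.put("dim", 10);        // hypothetical key for the latent vector size

        float gamma = getFloat(configs, "gamma", 0.01f);
        float lambda = getFloat(configs, "lambda", 0.01f); // absent, so the default is used
        int vectorSize = getInt(configs, "dim", 50);
        System.out.println(gamma + " " + lambda + " " + vectorSize);
    }
}
```

Accepting both numbers and strings is a defensive choice for a Map<String,?>, whose value types are not fixed by the signature.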
Constructor and Description |
---|
Sgd() |
Modifier and Type | Method and Description |
---|---|
void | compute(int superstep, VertexWithValue<CfLongId,org.jblas.FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId,Float>> edges, ComputeFunction.Callback<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage> cb): Computes a new vertex value or sends messages to the next superstep. |
protected long | getTotalNumEdges(ComputeFunction.ReadAggregators aggregators) |
void | init(Map<String,?> configs, ComputeFunction.InitCallback cb): Initializes the ComputeFunction; this is where aggregators are registered. |
void | masterCompute(int superstep, ComputeFunction.MasterCallback cb): Performs sequential computations between supersteps. |
void | preSuperstep(int superstep, ComputeFunction.Aggregators aggregators): Prepares for computation. |
void | superstepCompute(int superstep, VertexWithValue<CfLongId,org.jblas.FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId,Float>> edges, ComputeFunction.Callback<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage> cb): Main SGD compute method. |
protected void | updateValue(org.jblas.FloatMatrix value, org.jblas.FloatMatrix update, float rating, float minRating, float maxRating, float lambda, float gamma): Applies the SGD update logic to the provided vector. |
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ComputeFunction: postSuperstep
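The class and method summaries imply a superstep schedule: user factors are initialized in the first superstep, item factors in the second, and superstepCompute ("Main SGD compute method") handles later supersteps, while masterCompute can halt the computation. The sketch below is a hypothetical illustration of that routing plus a stopping rule. The halting criterion (a superstep budget, or an RMSE target where a target of 0 disables the check) is an assumption suggested by the ITERATIONS and RMSE_TARGET fields, not Sgd's documented logic, and the names SuperstepScheduleSketch, Phase, phaseFor and shouldHalt are invented.

```java
// Hypothetical sketch of the superstep routing and a masterCompute-style stopping test.
class SuperstepScheduleSketch {

    enum Phase { INIT_USERS, INIT_ITEMS, SGD_UPDATE }

    // Routing implied by the class summary; not Sgd's actual compute() body.
    static Phase phaseFor(int superstep) {
        if (superstep == 0) {
            return Phase.INIT_USERS;   // cf. Sgd.InitUsersComputation
        } else if (superstep == 1) {
            return Phase.INIT_ITEMS;   // cf. Sgd.InitItemsComputation
        }
        return Phase.SGD_UPDATE;       // cf. Sgd.superstepCompute
    }

    // Assumed rule: stop when the superstep budget is used up, or when an RMSE
    // target greater than 0 is set and the current RMSE has reached it.
    static boolean shouldHalt(int superstep, int maxSupersteps, double rmse, double rmseTarget) {
        boolean budgetExhausted = superstep >= maxSupersteps;
        boolean targetReached = rmseTarget > 0 && rmse <= rmseTarget;
        return budgetExhausted || targetReached;
    }

    public static void main(String[] args) {
        for (int s = 0; s < 4; s++) {
            System.out.println(s + " -> " + phaseFor(s));
        }
        System.out.println(shouldHalt(5, 20, 0.9, 1.0)); // true: RMSE target reached
    }
}
```

How ITERATIONS maps onto a superstep budget (for example, whether one iteration spans a user pass and an item pass) is not stated here, so maxSupersteps is left as a free parameter.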
public static final String RMSE_TARGET
public static final float RMSE_TARGET_DEFAULT
public static final String TOLERANCE
public static final float TOLERANCE_DEFAULT
public static final String ITERATIONS
public static final int ITERATIONS_DEFAULT
public static final String LAMBDA
public static final float LAMBDA_DEFAULT
public static final String GAMMA
public static final float GAMMA_DEFAULT
public static final String VECTOR_SIZE
public static final int VECTOR_SIZE_DEFAULT
public static final String MAX_RATING
public static final float MAX_RATING_DEFAULT
public static final String MIN_RATING
public static final float MIN_RATING_DEFAULT
public static final String RANDOM_SEED
public static final Long RANDOM_SEED_DEFAULT
public static final String RMSE_AGGREGATOR
protected float minRating
protected float maxRating
public void preSuperstep(int superstep, ComputeFunction.Aggregators aggregators)
Description copied from interface: ComputeFunction
Prepares for computation.
Specified by: preSuperstep in interface ComputeFunction<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage>
Parameters:
superstep - the superstep
aggregators - the aggregators
public void superstepCompute(int superstep, VertexWithValue<CfLongId,org.jblas.FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId,Float>> edges, ComputeFunction.Callback<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage> cb)
Main SGD compute method.
Parameters:
superstep - the count of the current superstep
vertex - the current vertex with its value
messages - the messages sent in the previous superstep, each associated with its source vertex
edges - the adjacent edges with their values
cb - a callback for setting a new vertex value or sending messages to the next superstep
protected final void updateValue(org.jblas.FloatMatrix value, org.jblas.FloatMatrix update, float rating, float minRating, float maxRating, float lambda, float gamma)
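Applies the SGD update logic to the provided vector. The update is performed according to the following formula:
v = v - gamma*(lambda*v + error*u)
where v is value, u is update, and error is the prediction error for the given rating.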
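The formula can be illustrated on plain float arrays instead of org.jblas.FloatMatrix. This is a minimal sketch under one stated assumption: error is the predicted rating (the dot product of v and u, clamped to [minRating, maxRating]) minus the observed rating, which is consistent with the minRating and maxRating parameters but is not spelled out in the description above. SgdUpdateSketch is a hypothetical name.

```java
import java.util.Arrays;

// Sketch of v = v - gamma*(lambda*v + error*u) on float arrays (not the jblas-based original).
class SgdUpdateSketch {

    static void updateValue(float[] value, float[] update, float rating,
                            float minRating, float maxRating, float lambda, float gamma) {
        // Assumption: the prediction is the dot product of the two latent vectors.
        float predicted = 0f;
        for (int i = 0; i < value.length; i++) {
            predicted += value[i] * update[i];
        }
        // Assumption: the prediction is clamped into [minRating, maxRating].
        predicted = Math.max(minRating, Math.min(maxRating, predicted));
        float error = predicted - rating;
        // In-place step, element by element: v_i -= gamma * (lambda * v_i + error * u_i).
        for (int i = 0; i < value.length; i++) {
            value[i] -= gamma * (lambda * value[i] + error * update[i]);
        }
    }

    public static void main(String[] args) {
        float[] v = {1.0f, 0.5f};
        float[] u = {2.0f, 2.0f};
        // predicted = 3, error = 3 - 5 = -2, so v moves toward u.
        updateValue(v, u, 5.0f, 0.0f, 5.0f, 0.1f, 0.1f);
        System.out.println(Arrays.toString(v)); // approximately [1.39, 0.895]
    }
}
```

The dot product is finished before any element of value changes, so the error term is computed from the old vector, as the formula requires.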
Parameters:
value - the vector to update
update - the vector used to update
rating - the rating
minRating - the minimum rating
maxRating - the maximum rating
lambda - the lambda (regularization) parameter
gamma - the gamma (learning rate) parameter
public final void init(Map<String,?> configs, ComputeFunction.InitCallback cb)
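Description copied from interface: ComputeFunction
Initializes the ComputeFunction; this is where aggregators are registered.
Specified by: init in interface ComputeFunction<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage>
Parameters:
configs - configuration parameters
cb - a callback for registering aggregators
public final void masterCompute(int superstep, ComputeFunction.MasterCallback cb)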
Description copied from interface: ComputeFunction
Performs sequential computations between supersteps.
Specified by: masterCompute in interface ComputeFunction<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage>
Parameters:
superstep - the superstep
cb - a callback for writing to aggregators or halting the computation
public void compute(int superstep, VertexWithValue<CfLongId,org.jblas.FloatMatrix> vertex, Iterable<FloatMatrixMessage> messages, Iterable<EdgeWithValue<CfLongId,Float>> edges, ComputeFunction.Callback<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage> cb)
Description copied from interface: ComputeFunction
Computes a new vertex value or sends messages to the next superstep.
Specified by: compute in interface ComputeFunction<CfLongId,org.jblas.FloatMatrix,Float,FloatMatrixMessage>
Parameters:
superstep - the count of the current superstep
vertex - the current vertex with its value
messages - the messages sent in the previous superstep, each associated with its source vertex
edges - the adjacent edges with their values
cb - a callback for setting a new vertex value or sending messages to the next superstep
protected long getTotalNumEdges(ComputeFunction.ReadAggregators aggregators)
Copyright © 2020. All rights reserved.