public class InMemoryGraphLookupTable extends Object implements GraphVectorLookupTable
| Modifier and Type | Field and Description |
|---|---|
| protected double[] | expTable |
| protected double | learningRate |
| protected static double | MAX_EXP |
| protected int | nVertices |
| protected org.nd4j.linalg.api.ndarray.INDArray | outWeights |
| protected BinaryTree | tree |
| protected int | vectorSize |
| protected org.nd4j.linalg.api.ndarray.INDArray | vertexVectors |
| Constructor and Description |
|---|
| InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate) |
| Modifier and Type | Method and Description |
|---|---|
| double | calculateProb(int first, int second): Calculate the probability of the second vertex given the first vertex, i.e., P(v_second \| v_first). |
| double | calculateScore(int first, int second): Calculate score. |
| org.nd4j.linalg.api.ndarray.INDArray | getInnerNodeVector(int innerNode) |
| int | getNumVertices(): Returns the number of vertices in the graph. |
| org.nd4j.linalg.api.ndarray.INDArray | getOutWeights() |
| BinaryTree | getTree() |
| org.nd4j.linalg.api.ndarray.INDArray | getVector(int idx): Get the vector for the vertex with index idx. |
| org.nd4j.linalg.api.ndarray.INDArray | getVertexVectors() |
| void | iterate(int first, int second): Conduct learning given a pair of vertices (in and out). |
| void | resetWeights(): Reset (randomize) the weights. |
| void | setLearningRate(double learningRate): Set the learning rate. |
| void | setVertexVectors(org.nd4j.linalg.api.ndarray.INDArray vertexVectors) |
| org.nd4j.linalg.api.ndarray.INDArray[][] | vectorsAndGradients(int first, int second): Returns the vertex vector and its gradient, plus the inner-node vectors and their gradients. Specifically, out[0] holds the vectors and out[1] holds the gradients for the corresponding vectors: out[0][0] is the vector for the first vertex and out[1][0] is the gradient for that vertex vector; out[0][i] (i > 0) is an inner-node vector along the path to the second vertex, and out[1][i] is the gradient for that inner-node vector. This design is used primarily to aid in testing (numerical gradient checks). |
| int | vectorSize(): The size of the vector representations. |
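vectorsAndGradients is described as existing mainly to support numerical gradient checks. The sketch below shows the general shape of such a check with central differences on plain double arrays instead of INDArray. The loss -log σ(x · w) for a single inner-node vector w is an assumption made for illustration (this page does not say which loss the class differentiates), and every name in the block is hypothetical. Here x stands in for the first vertex's vector and analyticGradient for its reported gradient.

```java
import java.util.function.ToDoubleFunction;

public class GradientCheckSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    /** Hypothetical loss for one inner node: -log(sigmoid(x . w)). */
    static double loss(double[] x, double[] w) {
        return -Math.log(sigmoid(dot(x, w)));
    }

    /** Analytic gradient of the loss with respect to x: -(1 - sigmoid(x . w)) * w. */
    static double[] analyticGradient(double[] x, double[] w) {
        double coeff = -(1.0 - sigmoid(dot(x, w)));
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) g[i] = coeff * w[i];
        return g;
    }

    /** Central-difference estimate of df/dx_i for every component i; x is restored afterwards. */
    static double[] numericalGradient(ToDoubleFunction<double[]> f, double[] x, double eps) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double orig = x[i];
            x[i] = orig + eps;
            double plus = f.applyAsDouble(x);
            x[i] = orig - eps;
            double minus = f.applyAsDouble(x);
            x[i] = orig; // restore the perturbed component
            g[i] = (plus - minus) / (2 * eps);
        }
        return g;
    }

    /** Largest relative error between two gradients, guarding against division by zero. */
    static double maxRelativeError(double[] a, double[] b) {
        double worst = 0.0;
        for (int i = 0; i < a.length; i++) {
            double denom = Math.max(Math.abs(a[i]) + Math.abs(b[i]), 1e-12);
            worst = Math.max(worst, Math.abs(a[i] - b[i]) / denom);
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] x = {0.1, -0.2, 0.3};
        double[] w = {0.5, 0.4, -0.6};
        double[] analytic = analyticGradient(x, w);
        double[] numeric = numericalGradient(v -> loss(v, w), x, 1e-6);
        System.out.println("max relative error = " + maxRelativeError(analytic, numeric));
    }
}
```

A small relative error (far below 1e-6 here) indicates the analytic gradient agrees with the finite-difference estimate; a check against the real class would compare out[1] entries with perturbations of the corresponding out[0] entries in the same way.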
protected int nVertices
protected int vectorSize
protected BinaryTree tree
protected org.nd4j.linalg.api.ndarray.INDArray vertexVectors
protected org.nd4j.linalg.api.ndarray.INDArray outWeights
protected double learningRate
protected double[] expTable
protected static double MAX_EXP
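The field names expTable and MAX_EXP suggest, although this page does not say so, a precomputed sigmoid lookup table in the style of the original word2vec C code: sigmoid values are cached on a grid over [-MAX_EXP, MAX_EXP) and inputs outside that range are clamped. The sketch below illustrates that technique only; the table size of 1000 and MAX_EXP = 6 are assumptions borrowed from word2vec.c, and the class name is hypothetical.

```java
public class SigmoidTableSketch {
    // Assumed values (word2vec.c defaults); the actual class may use different ones.
    static final double MAX_EXP = 6.0;
    static final int TABLE_SIZE = 1000;

    final double[] expTable = new double[TABLE_SIZE];

    SigmoidTableSketch() {
        for (int i = 0; i < TABLE_SIZE; i++) {
            // Map index i to z in [-MAX_EXP, MAX_EXP) and cache sigmoid(z) = e^z / (e^z + 1).
            double z = (i / (double) TABLE_SIZE * 2 - 1) * MAX_EXP;
            double e = Math.exp(z);
            expTable[i] = e / (e + 1);
        }
    }

    /** Approximate sigmoid(z); inputs outside (-MAX_EXP, MAX_EXP) are clamped to 0 or 1. */
    double sigmoid(double z) {
        if (z <= -MAX_EXP) return 0.0;
        if (z >= MAX_EXP) return 1.0;
        int idx = (int) ((z + MAX_EXP) * (TABLE_SIZE / MAX_EXP / 2));
        return expTable[Math.min(idx, TABLE_SIZE - 1)]; // guard against rounding up to TABLE_SIZE
    }

    public static void main(String[] args) {
        SigmoidTableSketch t = new SigmoidTableSketch();
        for (double z : new double[] {-8, -2, 0, 2, 8}) {
            System.out.printf("z=%5.1f  table=%.4f  exact=%.4f%n", z, t.sigmoid(z), 1 / (1 + Math.exp(-z)));
        }
    }
}
```

The design trades a small approximation error (at most about 0.003 with this grid spacing) for avoiding a call to Math.exp on every update.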
public InMemoryGraphLookupTable(int nVertices,
int vectorSize,
BinaryTree tree,
double learningRate)
public org.nd4j.linalg.api.ndarray.INDArray getVertexVectors()
public org.nd4j.linalg.api.ndarray.INDArray getOutWeights()
public int vectorSize()
Specified by: vectorSize in interface GraphVectorLookupTable
public void resetWeights()
Specified by: resetWeights in interface GraphVectorLookupTable
public void iterate(int first,
int second)
Specified by: iterate in interface GraphVectorLookupTable
public org.nd4j.linalg.api.ndarray.INDArray[][] vectorsAndGradients(int first,
int second)
Parameters:
first - first (input) vertex index
second - second (output) vertex index
public double calculateProb(int first,
int second)
Parameters:
first - index of the first vertex
second - index of the second vertex
public double calculateScore(int first,
int second)
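calculateProb is documented as P(v_second | v_first). Given the BinaryTree field and the per-inner-node vectors exposed by getInnerNodeVector, the usual hierarchical-softmax reading is a product of sigmoids, one per inner node on the path from the root to the second vertex's leaf, each applied to the dot product of the first vertex's vector with that node's vector. The sketch below assumes that reading. The branch-sign convention, the four-vertex tree and all names (pathProbability and the arrays) are hypothetical, and plain double arrays stand in for INDArray.

```java
public class HierarchicalSoftmaxSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    /**
     * P(leaf | input) as a product of binary decisions along the leaf's root-to-leaf path.
     * pathNodes[k] indexes the k-th inner node on the path; goesRight[k] says which child is taken.
     * Assumed convention: going left contributes sigmoid(w . x), going right sigmoid(-w . x).
     */
    static double pathProbability(double[] inputVector, double[][] innerNodeVectors,
                                  int[] pathNodes, boolean[] goesRight) {
        double prob = 1.0;
        for (int k = 0; k < pathNodes.length; k++) {
            double z = dot(inputVector, innerNodeVectors[pathNodes[k]]);
            prob *= goesRight[k] ? sigmoid(-z) : sigmoid(z);
        }
        return prob;
    }

    public static void main(String[] args) {
        // Hypothetical balanced tree over 4 vertices: inner node 0 is the root,
        // inner node 1 splits vertices {0, 1}, inner node 2 splits vertices {2, 3}.
        int[][] paths = {{0, 1}, {0, 1}, {0, 2}, {0, 2}};
        boolean[][] codes = {{false, false}, {false, true}, {true, false}, {true, true}};
        double[][] inner = {{0.3, -0.1}, {0.2, 0.5}, {-0.4, 0.1}};
        double[] vFirst = {0.7, -0.2}; // vector of the "first" vertex
        double total = 0.0;
        for (int v = 0; v < 4; v++) {
            double p = pathProbability(vFirst, inner, paths[v], codes[v]);
            total += p;
            System.out.printf("P(v%d | first) = %.4f%n", v, p);
        }
        System.out.printf("sum = %.6f%n", total); // sums to 1 up to rounding, since sigmoid(z) + sigmoid(-z) = 1
    }
}
```

Because each inner node splits its probability mass as σ(z) and σ(-z), the leaf probabilities always sum to 1, and computing one of them costs time proportional to the path length (about log2 of the number of vertices in a balanced tree) rather than to the number of vertices.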
public BinaryTree getTree()
public org.nd4j.linalg.api.ndarray.INDArray getInnerNodeVector(int innerNode)
public org.nd4j.linalg.api.ndarray.INDArray getVector(int idx)
Specified by: getVector in interface GraphVectorLookupTable
public void setLearningRate(double learningRate)
Specified by: setLearningRate in interface GraphVectorLookupTable
public int getNumVertices()
Specified by: getNumVertices in interface GraphVectorLookupTable
public void setVertexVectors(org.nd4j.linalg.api.ndarray.INDArray vertexVectors)
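iterate(first, second) is documented only as conducting learning given a pair of vertices (in and out). Assuming it performs a hierarchical-softmax stochastic gradient step like word2vec's skip-gram update, which the BinaryTree, outWeights and learningRate fields suggest but this page does not confirm, one step could look like the sketch below. It is gradient ascent on log P(second | first): every inner-node vector on the second vertex's path is nudged, and the first vertex's vector receives the accumulated update. The sign convention (left branch uses σ(w · x)), the class and method names, and the toy data are hypothetical; plain arrays stand in for INDArray.

```java
public class HierarchicalSoftmaxStepSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    /** Path-product probability; codes[k] == true means the right branch, scored as sigmoid(-w . x). */
    static double pathProbability(double[] x, double[][] inner, int[] pathNodes, boolean[] codes) {
        double prob = 1.0;
        for (int k = 0; k < pathNodes.length; k++) {
            double z = dot(x, inner[pathNodes[k]]);
            prob *= codes[k] ? sigmoid(-z) : sigmoid(z);
        }
        return prob;
    }

    /** One SGD step on log P(second | first); x and the on-path inner vectors are updated in place. */
    static void iterate(double[] x, double[][] inner, int[] pathNodes, boolean[] codes, double learningRate) {
        double[] xUpdate = new double[x.length]; // accumulated update for x (word2vec calls this neu1e)
        for (int k = 0; k < pathNodes.length; k++) {
            double[] w = inner[pathNodes[k]];
            double f = sigmoid(dot(x, w));
            double label = codes[k] ? 0.0 : 1.0; // target value for sigmoid(x . w) at this node
            double g = (label - f) * learningRate;
            for (int i = 0; i < x.length; i++) xUpdate[i] += g * w[i]; // uses w before its own update
            for (int i = 0; i < x.length; i++) w[i] += g * x[i];
        }
        for (int i = 0; i < x.length; i++) x[i] += xUpdate[i];
    }

    public static void main(String[] args) {
        double[][] inner = {{0.3, -0.1}, {0.2, 0.5}, {-0.4, 0.1}};
        double[] first = {0.7, -0.2};
        int[] path = {0, 2};               // hypothetical path to the "second" vertex
        boolean[] code = {true, false};    // right at the root, then left
        System.out.printf("before: %.4f%n", pathProbability(first, inner, path, code));
        iterate(first, inner, path, code, 0.1);
        System.out.printf("after:  %.4f%n", pathProbability(first, inner, path, code));
    }
}
```

Only the inner nodes on the path are touched, so a step costs time proportional to the path length; with a small learning rate the probability of the observed pair rises after each step.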
Copyright © 2016. All Rights Reserved.