public class TensorMultiplicationVertex<T extends java.lang.Number,TENSOR extends NumberTensor<T,TENSOR>,VERTEX extends NumberTensorVertex<T,TENSOR,VERTEX>> extends BinaryTensorOpVertex<T,TENSOR,VERTEX> implements NonProbabilisticVertex<TENSOR,VERTEX>, Differentiable
left, LEFT_NAME, right, RIGHT_NAME
SCALAR_SHAPE
Constructor and Description |
---|
TensorMultiplicationVertex(TensorVertex<T,TENSOR,VERTEX> left,
TensorVertex<T,TENSOR,VERTEX> right,
int[] dimsLeft,
int[] dimsRight)
Tensor-multiplies one vertex by another along the specified dimensions.
|
Modifier and Type | Method and Description |
---|---|
ForwardModePartialDerivative |
forwardModeAutoDifferentiation(java.util.Map<Vertex,ForwardModePartialDerivative> derivativeOfParentsWithRespectToInput) |
int[] |
getDimsLeft() |
int[] |
getDimsRight() |
protected TENSOR |
op(TENSOR l,
TENSOR r) |
java.util.Map<Vertex,ReverseModePartialDerivative> |
reverseModeAutoDifferentiation(ReverseModePartialDerivative derivativeOfOutputWithRespectToSelf) |
calculate, getLeft, getRight, ofType, wrap
addChild, addParent, addParents, equals, eval, getChildren, getConnectedGraph, getDegree, getId, getIndentation, getLabel, getLength, getObservedValue, getParents, getRank, getReference, getShape, getState, getStride, getValue, hashCode, hasValue, isDifferentiable, isObserved, isProbabilistic, lazyEval, observe, observeOwnValue, print, print, removeLabel, setAndCascade, setLabel, setLabel, setParents, setParents, setState, setValue, toString, unobserve
clone, finalize, getClass, notify, notifyAll, wait, wait, wait
addChild, addParent, addParents, eval, getChildren, getConnectedGraph, getDegree, getId, getIndentation, getLabel, getLength, getObservedValue, getParents, getRank, getReference, getShape, getState, getStride, getValue, hasValue, isDifferentiable, isObserved, isProbabilistic, lazyEval, observe, observeOwnValue, ofType, print, print, removeLabel, setAndCascade, setLabel, setLabel, setParents, setParents, setState, setValue, unobserve
calculate, contradictsObservation
ofSelfWrtSelf, wrtSelfOfSelf
broadcast, diag, diagPart, elementwiseEquals, elementwiseEquals, fillTriangular, get, notEqualTo, notEqualTo, permute, reshape, slice, slice, take, trianglePart, triLower, triUpper, where
expandDims, getLength, getRank, getShape, getStride, isLengthOne, isMatrix, isScalar, isVector, moveAxis, slice, sliceAlongDimension, squeeze, swapAxis, transpose
public TensorMultiplicationVertex(TensorVertex<T,TENSOR,VERTEX> left, TensorVertex<T,TENSOR,VERTEX> right, int[] dimsLeft, int[] dimsRight)
Parameters:
left - the left operand vertex
right - the right operand vertex
dimsLeft - the dimensions of the left vertex to multiply over. The left shape at these dimensions must align with the shape of the right vertex at its corresponding specified dimensions.
dimsRight - the dimensions of the right vertex to multiply over. The right shape at these dimensions must align with the shape of the left vertex at its corresponding specified dimensions.
protected TENSOR op(TENSOR l, TENSOR r)
op
in class BinaryTensorOpVertex<T extends java.lang.Number,TENSOR extends NumberTensor<T,TENSOR>,VERTEX extends NumberTensorVertex<T,TENSOR,VERTEX>>
public java.util.Map<Vertex,ReverseModePartialDerivative> reverseModeAutoDifferentiation(ReverseModePartialDerivative derivativeOfOutputWithRespectToSelf)
reverseModeAutoDifferentiation
in interface Differentiable
public ForwardModePartialDerivative forwardModeAutoDifferentiation(java.util.Map<Vertex,ForwardModePartialDerivative> derivativeOfParentsWithRespectToInput)
forwardModeAutoDifferentiation
in interface Differentiable
public int[] getDimsLeft()
public int[] getDimsRight()