public class LSTMBlock extends DynamicCustomOp
See also LSTMBlockCell; the lstmBlockCell op is used internally at the C++ level for the computation.
Input arrays:
0: maximum sequence length; long/int64 scalar
1: input sequence, [seqLength, bS, inSize] (TNS layout; the actual layout follows the data format argument below)
2: previous/initial cell state, [bS, numUnits]
3: previous/initial output, [bS, numUnits]
4: weights - concatenated input-to-hidden and hidden-to-hidden weights, [(inSize+numUnits), 4*numUnits]
5: weights - cell peephole (t-1) connections to the input modulation gate, [numUnits]
6: weights - cell peephole (t-1) connections to the forget gate, [numUnits]
7: weights - cell peephole (t) connections to the output gate, [numUnits]
8: biases, [4*numUnits]
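The shape constraints above can be checked before building the op. The following is a minimal plain-Java sketch, assuming the TNS layout for input 1; the class LstmBlockShapes and its method are hypothetical helpers, not part of ND4J.

```java
import java.util.Arrays;

/** Hypothetical helper: expected shapes of the LSTMBlock inputs (TNS layout assumed for input 1). */
class LstmBlockShapes {

    /** Returns one shape per input array, indexed as in the list above; index 0 is a scalar. */
    static long[][] expectedInputShapes(long seqLen, long bS, long inSize, long numUnits) {
        return new long[][] {
            {},                                  // 0: max sequence length, scalar
            {seqLen, bS, inSize},                // 1: input sequence, TNS layout
            {bS, numUnits},                      // 2: previous/initial cell state
            {bS, numUnits},                      // 3: previous/initial output
            {inSize + numUnits, 4 * numUnits},   // 4: concatenated input-to-hidden and hidden-to-hidden weights
            {numUnits},                          // 5: peephole weights, input modulation gate
            {numUnits},                          // 6: peephole weights, forget gate
            {numUnits},                          // 7: peephole weights, output gate
            {4 * numUnits}                       // 8: biases
        };
    }

    public static void main(String[] args) {
        long[][] shapes = expectedInputShapes(10, 32, 8, 16);
        for (int i = 0; i < shapes.length; i++) {
            System.out.println(i + ": " + Arrays.toString(shapes[i]));
        }
    }
}
```

The 4*numUnits factor reflects the four gate blocks (input modulation, block input, forget, output) packed side by side in one matrix, so a single matrix multiply produces all gate pre-activations.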
Input integer arguments: set via LSTMConfiguration
0: if nonzero, use peephole connections
1: Data format - 0=TNS=[seqLen,mb,size]; 1=NST=[mb,size,seqLen]; 2=NTS=[mb,seqLen,size]
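As an illustration of how the data format code determines the rank-3 layout of the input (and of each output array, with size = numUnits), the hypothetical helper below maps the integer argument to a shape. It is a plain-Java sketch and does not use any ND4J types.

```java
import java.util.Arrays;

/** Hypothetical helper mapping the data format integer argument to a rank-3 shape. */
class RnnLayout {

    /** dataFormat: 0 = TNS [seqLen, mb, size], 1 = NST [mb, size, seqLen], 2 = NTS [mb, seqLen, size]. */
    static long[] shapeFor(int dataFormat, long seqLen, long mb, long size) {
        switch (dataFormat) {
            case 0: return new long[] {seqLen, mb, size};
            case 1: return new long[] {mb, size, seqLen};
            case 2: return new long[] {mb, seqLen, size};
            default: throw new IllegalArgumentException("Unknown data format: " + dataFormat);
        }
    }

    public static void main(String[] args) {
        // The same logical sequence (seqLen = 10, minibatch = 32, size = 8) in each layout
        for (int fmt = 0; fmt <= 2; fmt++) {
            System.out.println(fmt + " -> " + Arrays.toString(shapeFor(fmt, 10, 32, 8)));
        }
    }
}
```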
Input float arguments: set via LSTMConfiguration
0: forget bias, added to the forget gates to reduce the amount of forgetting at the beginning of training
1: cell state clipping value; if it is nonzero, the cell state is clipped
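A scalar sketch of the two float arguments follows. It assumes the forget bias is added to the forget gate pre-activation before the sigmoid, and that clipping is symmetric, to [-clip, clip]; both are assumptions about the C++ kernel, and the class and method names are hypothetical.

```java
/** Hypothetical scalar illustration of the forget bias and cell clipping arguments. */
class LstmFloatArgs {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Forget gate activation with the forget bias added to the pre-activation. */
    static double forgetGate(double preActivation, double forgetBias) {
        return sigmoid(preActivation + forgetBias);
    }

    /** Clips the cell state to [-clip, clip]; a clip value of zero disables clipping. */
    static double clipCell(double c, double clip) {
        if (clip == 0.0) {
            return c;
        }
        return Math.max(-clip, Math.min(clip, c));
    }

    public static void main(String[] args) {
        // With a zero pre-activation, a forget bias of 1.0 raises f from 0.5 to about 0.73
        System.out.println(forgetGate(0.0, 0.0));
        System.out.println(forgetGate(0.0, 1.0));
        System.out.println(clipCell(5.0, 3.0));   // 3.0
        System.out.println(clipCell(5.0, 0.0));   // 5.0, clipping disabled
    }
}
```

Pushing f toward 1 early in training means the cell initially keeps most of its previous state, which is the "reduce the scale of forgetting" effect described above.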
Output arrays:
0: i - input modulation gate activations, rank 3, shape as per dataFormat
1: c (cs) - cell state (pre tanh), rank 3, shape as per dataFormat
2: f - forget gate activations, rank 3, shape as per dataFormat
3: o - output gate activations, rank 3, shape as per dataFormat
4: z (ci) - block input, rank 3, shape as per dataFormat
5: h (co) - cell state (post tanh), rank 3, shape as per dataFormat
6: y (h) - current cell output, rank 3, shape as per dataFormat
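To relate the seven outputs to one another, the sketch below computes a single time step for one example (bS = 1) in plain Java. It assumes the gate ordering [i, z, f, o] within the 4*numUnits columns of the weights and biases, the peephole wiring listed under inputs 5 to 7, and symmetric cell clipping. These are assumptions for illustration; the class LstmStep is hypothetical and is not the C++ lstmBlockCell implementation.

```java
import java.util.Arrays;

/** Hypothetical single-example reference step; an illustration, not the libnd4j code. */
class LstmStep {

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    /**
     * One time step. Returns {i, cs, f, o, z, co, h}, in the same order as output arrays 0..6.
     * Assumes the 4*numUnits columns of w and b are ordered [i, z, f, o].
     */
    static double[][] step(double[] x, double[] cPrev, double[] yPrev,
                           double[][] w, double[] wci, double[] wcf, double[] wco, double[] b,
                           boolean peephole, double forgetBias, double clip) {
        int n = cPrev.length;
        int inSize = x.length;
        // Concatenate [x, yPrev] to match the (inSize + numUnits) rows of the weight matrix
        double[] xy = new double[inSize + n];
        System.arraycopy(x, 0, xy, 0, inSize);
        System.arraycopy(yPrev, 0, xy, inSize, n);

        // Pre-activations for all four gates: m = [x, yPrev] * w + b, length 4*numUnits
        double[] m = b.clone();
        for (int r = 0; r < xy.length; r++) {
            for (int col = 0; col < 4 * n; col++) {
                m[col] += xy[r] * w[r][col];
            }
        }

        double[] i = new double[n], cs = new double[n], f = new double[n], o = new double[n];
        double[] z = new double[n], co = new double[n], h = new double[n];
        for (int k = 0; k < n; k++) {
            // Input modulation and forget gates peek at the previous cell state c(t-1)
            double pi = peephole ? wci[k] * cPrev[k] : 0.0;
            double pf = peephole ? wcf[k] * cPrev[k] : 0.0;
            i[k] = sigmoid(m[k] + pi);
            z[k] = Math.tanh(m[n + k]);
            f[k] = sigmoid(m[2 * n + k] + pf + forgetBias);
            cs[k] = f[k] * cPrev[k] + i[k] * z[k];
            if (clip != 0.0) {
                cs[k] = Math.max(-clip, Math.min(clip, cs[k]));
            }
            // The output gate peeks at the current cell state c(t)
            double po = peephole ? wco[k] * cs[k] : 0.0;
            o[k] = sigmoid(m[3 * n + k] + po);
            co[k] = Math.tanh(cs[k]);
            h[k] = o[k] * co[k];
        }
        return new double[][] {i, cs, f, o, z, co, h};
    }

    public static void main(String[] args) {
        int inSize = 3, n = 2;
        double[][] w = new double[inSize + n][4 * n];  // all-zero weights
        double[] zeros = new double[n];
        double[][] out = step(new double[inSize], new double[] {1.0, -1.0}, zeros,
                w, zeros, zeros, zeros, new double[4 * n], true, 0.0, 0.0);
        // With zero weights every gate is sigmoid(0) = 0.5 and z = 0, so the cell state halves
        System.out.println(Arrays.toString(out[1]));  // [0.5, -0.5]
    }
}
```

Conceptually, LSTMBlock applies a step like this along the time axis, feeding each step's cs and h back in as the previous cell state and output.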
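Nested classes/interfaces inherited from class DynamicCustomOp: DynamicCustomOp.DynamicCustomOpsBuilder
Fields inherited from class DynamicCustomOp: axis, bArguments, dArguments, iArguments, inplaceCall, inputArguments, outputArguments, outputVariables, tArguments
Fields inherited from class DifferentialFunction: dimensions, extraArgs, inPlace, sameDiff, scalarValue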
Constructor and Description |
---|
LSTMBlock() |
LSTMBlock(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration) |
LSTMBlock(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration) |
Modifier and Type | Method and Description |
---|---|
List<DataType> | calculateOutputDataTypes(List<DataType> inputDataTypes): Calculate the data types for the output arrays. |
List<SDVariable> | doDiff(List<SDVariable> grads): The actual implementation for automatic differentiation. |
void | initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String,AttrValue> attributesForNode, GraphDef graph): Initialize the function from the given NodeDef. |
String | opName(): Returns the op name as a string. |
Map<String,Object> | propertiesForFunction(): Returns the properties for a given function. |
String | tensorflowName(): Returns the TensorFlow name of this function. |
Methods inherited from class DynamicCustomOp: addBArgument, addDArgument, addIArgument, addIArgument, addInputArgument, addOutputArgument, addTArgument, assertValidForExecution, bArgs, builder, calculateOutputShape, calculateOutputShape, clearArrays, dArgs, getBArgument, getDescriptor, getIArgument, getInputArgument, getOutputArgument, getTArgument, iArgs, initFromOnnx, inputArguments, numBArguments, numDArguments, numIArguments, numInputArguments, numOutputArguments, numTArguments, onnxName, opHash, opNum, opType, outputArguments, outputVariables, outputVariables, removeIArgument, removeInputArgument, removeOutputArgument, removeTArgument, setInputArgument, setInputArguments, setOutputArgument, tArgs, toString, wrapFilterNull, wrapOrNull, wrapOrNull
Methods inherited from class DifferentialFunction: arg, arg, argNames, args, attributeAdaptersForFunction, configFieldName, diff, dup, equals, getNumOutputs, getValue, hashCode, isConfigProperties, larg, mappingsForFunction, onnxNames, outputs, outputVariable, outputVariablesNames, rarg, replaceArg, setInstanceId, setPropertiesForFunction, setValueFor, tensorflowNames
Methods inherited from class java.lang.Object: clone, finalize, getClass, notify, notifyAll, wait, wait, wait
Methods inherited from interface CustomOp: isInplaceCall
public LSTMBlock()
public LSTMBlock(@NonNull SameDiff sameDiff, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights weights, LSTMConfiguration configuration)
public LSTMBlock(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMWeights lstmWeights, LSTMConfiguration lstmConfiguration)
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes)
Description copied from class: DifferentialFunction
Calculate the data types for the output arrays. Although data types can also be inferred via DifferentialFunction.calculateOutputShape(), this method differs in that it does not require the input arrays to be populated. This is important because it allows greedy data type inference for the entire network, even when arrays are not available.
Overrides: calculateOutputDataTypes in class DifferentialFunction
Parameters: inputDataTypes - the data types of the inputs
public List<SDVariable> doDiff(List<SDVariable> grads)
Description copied from class: DifferentialFunction
The actual implementation for automatic differentiation.
Overrides: doDiff in class DynamicCustomOp
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String,AttrValue> attributesForNode, GraphDef graph)
Description copied from class: DifferentialFunction
Initialize the function from the given NodeDef.
Overrides: initFromTensorFlow in class DynamicCustomOp
public String opName()
Description copied from class: DynamicCustomOp
Returns the op name as a string.
Specified by: opName in interface CustomOp
Overrides: opName in class DynamicCustomOp
public Map<String,Object> propertiesForFunction()
Description copied from class: DifferentialFunction
Returns the properties for a given function.
Overrides: propertiesForFunction in class DifferentialFunction
public String tensorflowName()
Description copied from class: DifferentialFunction
Returns the TensorFlow name of this function.
Overrides: tensorflowName in class DynamicCustomOp
Copyright © 2020. All rights reserved.