public abstract class BaseRecurrentLayer<LayerConfT extends BaseRecurrentLayer> extends BaseLayer<LayerConfT> implements RecurrentLayer
Nested classes/interfaces inherited from interface Layer: Layer.TrainingMode, Layer.Type

| Modifier and Type | Field and Description |
|---|---|
| protected int | helperCountFail |
| protected Map<String,INDArray> | stateMap: stores the INDArrays needed for the rnnTimeStep() forward pass. |
| protected Map<String,INDArray> | tBpttStateMap: state map used specifically in truncated BPTT training. |
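The field summary says only that `stateMap` keeps the arrays needed for the `rnnTimeStep()` forward pass. The following is a minimal, dependency-free sketch of that idea under stated assumptions, not DL4J code: the class `ToyRecurrentLayer`, the key `"prevActivation"`, the update rule h_t = tanh(w * h_{t-1} + x_t) and the use of `double[]` in place of `INDArray` are all hypothetical. It shows how a string-keyed map lets consecutive single-step calls continue one sequence, and how clearing the map (the role `rnnClearPreviousState()` plays) restarts from zero state.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy layer illustrating a stateMap; NOT the DL4J implementation.
class ToyRecurrentLayer {
    // Hypothetical key; a real layer defines its own keys.
    static final String PREV_ACTIVATION = "prevActivation";

    private final double recurrentWeight;
    // Plays the role of stateMap: arrays kept between single-step calls.
    private final Map<String, double[]> stateMap = new HashMap<>();

    ToyRecurrentLayer(double recurrentWeight) {
        this.recurrentWeight = recurrentWeight;
    }

    // One time step for a minibatch of scalar inputs: h_t = tanh(w * h_{t-1} + x_t).
    double[] timeStep(double[] input) {
        double[] prev = stateMap.get(PREV_ACTIVATION);
        double[] out = new double[input.length];
        for (int i = 0; i < input.length; i++) {
            double h = (prev == null) ? 0.0 : prev[i]; // no stored state: start from zero
            out[i] = Math.tanh(recurrentWeight * h + input[i]);
        }
        stateMap.put(PREV_ACTIVATION, out.clone()); // remember for the next call
        return out;
    }

    // Analogue of rnnClearPreviousState(): forget everything stored so far.
    void clearPreviousState() {
        stateMap.clear();
    }

    boolean hasState() {
        return !stateMap.isEmpty();
    }
}

public class StateMapDemo {
    public static void main(String[] args) {
        ToyRecurrentLayer layer = new ToyRecurrentLayer(0.5);
        double[] first = layer.timeStep(new double[] {1.0});
        double[] second = layer.timeStep(new double[] {1.0}); // continues from first
        layer.clearPreviousState();
        double[] restarted = layer.timeStep(new double[] {1.0}); // zero state again
        System.out.println(second[0] != first[0]);    // true: state was carried over
        System.out.println(restarted[0] == first[0]); // true: state was cleared
    }
}
```

Storing a copy of the output keeps the returned array and the stored state independent, so a caller mutating the result cannot silently change the next step.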
Fields inherited from class BaseLayer: gradient, gradientsFlattened, gradientViews, optimizer, params, paramsFlattened, score, solver, weightNoiseParams

Fields inherited from class AbstractLayer: cacheMode, conf, dataType, dropoutApplied, epochCount, index, input, inputModificationAllowed, iterationCount, maskArray, maskState, preOutput, trainingListeners

| Constructor and Description |
|---|
| BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType) |
| Modifier and Type | Method and Description |
|---|---|
| RNNFormat | getDataFormat() |
| protected INDArray | permuteIfNWC(INDArray arr) |
| void | rnnClearPreviousState(): resets (clears) the stateMap used by rnnTimeStep() and the tBpttStateMap used by rnnActivateUsingStoredState(). |
| Map<String,INDArray> | rnnGetPreviousState(): returns a shallow copy of the stateMap. |
| Map<String,INDArray> | rnnGetTBPTTState(): gets the truncated backpropagation through time (TBPTT) state of the recurrent layer. |
| void | rnnSetPreviousState(Map<String,INDArray> stateMap): sets the state map used by rnnTimeStep(). |
| void | rnnSetTBPTTState(Map<String,INDArray> state): sets the truncated backpropagation through time (TBPTT) state of the recurrent layer. |
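`rnnGetPreviousState()` is documented as returning a shallow copy of `stateMap`. Assuming "shallow" carries its usual Java meaning, the returned map is a new object but the arrays inside it are the same objects the layer holds. The sketch below uses `HashMap` and `double[]` as stand-ins (the class `ToyStateHolder` and its methods are hypothetical, not DL4J types) to show the consequence: adding or removing keys in the copy leaves the layer alone, while writing into one of its arrays changes the layer's stored state. An independent snapshot needs each array copied too; in ND4J, `INDArray.dup()` does that for a single array.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical holder of a stateMap, using double[] in place of INDArray.
class ToyStateHolder {
    final Map<String, double[]> stateMap = new HashMap<>();

    // Shallow copy, mirroring the documented behavior of rnnGetPreviousState():
    // a new map object whose values are the SAME array objects.
    Map<String, double[]> getPreviousState() {
        return new HashMap<>(stateMap);
    }

    // Independent snapshot: every array is copied as well.
    Map<String, double[]> deepCopyState() {
        Map<String, double[]> copy = new HashMap<>();
        for (Map.Entry<String, double[]> e : stateMap.entrySet()) {
            copy.put(e.getKey(), e.getValue().clone());
        }
        return copy;
    }
}

public class ShallowCopyDemo {
    public static void main(String[] args) {
        ToyStateHolder layer = new ToyStateHolder();
        layer.stateMap.put("prevActivation", new double[] {0.1, 0.2});

        // Removing a key from the copy does not touch the layer's map.
        layer.getPreviousState().remove("prevActivation");
        System.out.println(layer.stateMap.containsKey("prevActivation")); // true

        // Writing into a deep-copied array does not touch the layer's array.
        layer.deepCopyState().get("prevActivation")[0] = 42.0;
        System.out.println(layer.stateMap.get("prevActivation")[0]); // 0.1

        // Writing into an array obtained via the shallow copy DOES change the layer's state.
        layer.getPreviousState().get("prevActivation")[1] = -1.0;
        System.out.println(layer.stateMap.get("prevActivation")[1]); // -1.0
    }
}
```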
Methods inherited from class BaseLayer: activate, backpropGradient, calcRegularizationScore, clear, clearNoiseWeightParams, clone, computeGradientAndScore, fit, fit, getGradientsViewArray, getOptimizer, getParam, getParamWithNoise, gradient, hasBias, hasLayerNorm, layerConf, numParams, params, paramTable, paramTable, preOutput, preOutputWithPreNorm, score, setBackpropGradientsViewArray, setParam, setParams, setParams, setParamsViewArray, setParamTable, setScoreWithZ, toString, update, update

Methods inherited from class AbstractLayer: activate, addListeners, allowInputModification, applyConstraints, applyDropOutIfNecessary, applyMask, assertInputSet, backpropDropOutIfPresent, batchSize, close, conf, feedForwardMaskArray, getConfig, getEpochCount, getHelper, getIndex, getInput, getInputMiniBatchSize, getListeners, getMaskArray, gradientAndScore, init, input, layerId, numParams, setCacheMode, setConf, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setListeners, setListeners, setMaskArray, type, updaterDivideByMinibatch

Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait

Methods inherited from interface RecurrentLayer: rnnActivateUsingStoredState, rnnTimeStep, tbpttBackpropGradient

Methods inherited from interface Layer: activate, activate, allowInputModification, backpropGradient, calcRegularizationScore, clearNoiseWeightParams, feedForwardMaskArray, getEpochCount, getHelper, getIndex, getInputMiniBatchSize, getIterationCount, getListeners, getMaskArray, isPretrainLayer, setCacheMode, setEpochCount, setIndex, setInput, setInputMiniBatchSize, setIterationCount, setListeners, setListeners, setMaskArray, type

Methods inherited from interface Model: addListeners, applyConstraints, batchSize, clear, close, computeGradientAndScore, conf, fit, fit, getGradientsViewArray, getOptimizer, getParam, gradient, gradientAndScore, init, input, numParams, numParams, params, paramTable, paramTable, score, setBackpropGradientsViewArray, setConf, setParam, setParams, setParamsViewArray, setParamTable, update, update

Methods inherited from interface Trainable: getConfig, getGradientsViewArray, numParams, params, paramTable, updaterDivideByMinibatch

protected Map<String,INDArray> stateMap
protected Map<String,INDArray> tBpttStateMap
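Because `tBpttStateMap` is separate from `stateMap`, state carried between truncated BPTT segments during training need not disturb the state used by `rnnTimeStep()`. As a hedged, forward-only illustration (toy code, not DL4J's training loop), the sketch below splits one scalar sequence into fixed-length segments and hands the last activation from one segment to the next through a TBPTT-only map. The names `ToyTbpttLayer`, `runSegmented`, `segmentLength` and the key `"prevActivation"` are hypothetical. Real TBPTT also cuts gradient flow at each segment boundary, which this sketch does not model.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical toy, forward pass only; NOT DL4J's TBPTT implementation.
class ToyTbpttLayer {
    static final String PREV_ACTIVATION = "prevActivation"; // hypothetical key
    private final double recurrentWeight;
    // Plays the role of tBpttStateMap: state carried between TBPTT segments.
    final Map<String, Double> tBpttStateMap = new HashMap<>();

    ToyTbpttLayer(double recurrentWeight) {
        this.recurrentWeight = recurrentWeight;
    }

    // Forward pass over one segment, starting from the stored TBPTT state (or zero).
    double[] activateSegment(double[] segment) {
        double h = tBpttStateMap.getOrDefault(PREV_ACTIVATION, 0.0);
        double[] out = new double[segment.length];
        for (int t = 0; t < segment.length; t++) {
            h = Math.tanh(recurrentWeight * h + segment[t]);
            out[t] = h;
        }
        tBpttStateMap.put(PREV_ACTIVATION, h); // hand the last activation to the next segment
        return out;
    }
}

public class TbpttDemo {
    // Runs a sequence as consecutive segments of at most segmentLength steps.
    static double[] runSegmented(ToyTbpttLayer layer, double[] sequence, int segmentLength) {
        double[] all = new double[sequence.length];
        for (int start = 0; start < sequence.length; start += segmentLength) {
            int end = Math.min(start + segmentLength, sequence.length);
            double[] out = layer.activateSegment(Arrays.copyOfRange(sequence, start, end));
            System.arraycopy(out, 0, all, start, out.length);
        }
        return all;
    }

    public static void main(String[] args) {
        double[] sequence = {1.0, -0.5, 0.25, 0.75, -1.0};
        double[] segmented = runSegmented(new ToyTbpttLayer(0.5), sequence, 2);
        double[] full = new ToyTbpttLayer(0.5).activateSegment(sequence);
        // With state carried over, the forward activations match one full pass.
        System.out.println(Arrays.equals(segmented, full)); // true
    }
}
```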
protected int helperCountFail
public BaseRecurrentLayer(NeuralNetConfiguration conf, DataType dataType)
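public Map<String,INDArray> rnnGetPreviousState()
Returns a shallow copy of the stateMap.
Specified by: rnnGetPreviousState in interface RecurrentLayer

public void rnnSetPreviousState(Map<String,INDArray> stateMap)
Specified by: rnnSetPreviousState in interface RecurrentLayer

public void rnnClearPreviousState()
Specified by: rnnClearPreviousState in interface RecurrentLayer

public Map<String,INDArray> rnnGetTBPTTState()
Description copied from interface: RecurrentLayer
Gets the truncated backpropagation through time (TBPTT) state of the recurrent layer.
Specified by: rnnGetTBPTTState in interface RecurrentLayer

public void rnnSetTBPTTState(Map<String,INDArray> state)
Description copied from interface: RecurrentLayer
Sets the truncated backpropagation through time (TBPTT) state of the recurrent layer.
Specified by: rnnSetTBPTTState in interface RecurrentLayer
Parameters: state - TBPTT state to set

public RNNFormat getDataFormat()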
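`permuteIfNWC(INDArray arr)` is listed without a description. Assuming DL4J's usual meaning of `RNNFormat` (NCW = [minibatch, size, timeSteps], NWC = [minibatch, timeSteps, size]), its name suggests it swaps the last two axes, a (0, 2, 1) permutation, when `getDataFormat()` reports NWC and returns the array unchanged otherwise; this page does not confirm the implementation. The sketch below performs that swap on a plain `double[][][]`. The names `PermuteDemo`, `Format`, `swapLastTwoAxes` and `permuteIfNwc` are hypothetical stand-ins, not DL4J API.

```java
import java.util.Arrays;

// Hypothetical sketch of a (0, 2, 1) permutation on nested arrays; not DL4J code.
public class PermuteDemo {
    enum Format { NCW, NWC } // stand-in for RNNFormat

    // Swaps axes 1 and 2: [batch][a][b] -> [batch][b][a].
    // Assumes a non-empty, rectangular array.
    static double[][][] swapLastTwoAxes(double[][][] x) {
        int batch = x.length, a = x[0].length, b = x[0][0].length;
        double[][][] out = new double[batch][b][a];
        for (int n = 0; n < batch; n++)
            for (int i = 0; i < a; i++)
                for (int j = 0; j < b; j++)
                    out[n][j][i] = x[n][i][j];
        return out;
    }

    // Analogue of permuteIfNWC: only NWC input is rearranged; NCW passes through.
    static double[][][] permuteIfNwc(double[][][] x, Format format) {
        if (x == null) return null;
        return format == Format.NWC ? swapLastTwoAxes(x) : x;
    }

    public static void main(String[] args) {
        // One example, 3 time steps, 2 features, in NWC layout: [1][3][2].
        double[][][] nwc = {{{1, 2}, {3, 4}, {5, 6}}};
        double[][][] ncw = permuteIfNwc(nwc, Format.NWC);
        System.out.println(ncw[0].length + " x " + ncw[0][0].length); // 2 x 3
        System.out.println(Arrays.toString(ncw[0][0]));                // [1.0, 3.0, 5.0]
    }
}
```

Since swapping two axes is its own inverse, the same helper converts NCW back to NWC.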
Copyright © 2020. All rights reserved.