public class ReverseTimeSeriesVertex extends GraphVertex
Masks: The input might be masked (to allow for varying time series lengths in one minibatch). In this case the
present input (mask array = 1) is reversed in place and the padding (mask array = 0) is left untouched at
the same position. For a time series of length n, this normally means that the first n time steps are reversed and
the following padding is left untouched, but more complex masks are supported (e.g. [1, 0, 1, 0, ...]).
Note: In order to use mask arrays, the constructor must be called with the name of a network input. The mask of this input is then also used in this vertex.
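The masking behaviour described above can be illustrated with a minimal plain-Java sketch. It does not use the DL4J API: the class `MaskedReverseSketch` and its methods are hypothetical names introduced here, and the code works on a single one-dimensional series rather than on a minibatch of INDArrays. It only assumes the semantics stated in the text: steps with mask value 1 are reversed among themselves, and steps with mask value 0 stay where they are.

```java
import java.util.Arrays;

// Illustrative sketch only; not the library's implementation.
final class MaskedReverseSketch {

    /** Reverses the whole series (the behaviour without a mask). */
    static double[] reverse(double[] series) {
        double[] out = new double[series.length];
        for (int t = 0; t < series.length; t++) {
            out[t] = series[series.length - 1 - t];
        }
        return out;
    }

    /**
     * Reverses only the steps whose mask value is 1.
     * Steps whose mask value is 0 (padding) keep their position and value.
     */
    static double[] reverseMasked(double[] series, int[] mask) {
        if (series.length != mask.length) {
            throw new IllegalArgumentException("series and mask must have the same length");
        }
        double[] out = series.clone();
        int lo = 0;
        int hi = series.length - 1;
        while (true) {
            // Skip padding from both ends.
            while (lo < hi && mask[lo] == 0) lo++;
            while (lo < hi && mask[hi] == 0) hi--;
            if (lo >= hi) break;
            // Swap the outermost pair of present steps.
            out[lo] = series[hi];
            out[hi] = series[lo];
            lo++;
            hi--;
        }
        return out;
    }

    public static void main(String[] args) {
        // Series of length 3, padded to 5 time steps.
        double[] padded = {1, 2, 3, 0, 0};
        int[] paddedMask = {1, 1, 1, 0, 0};
        System.out.println(Arrays.toString(reverseMasked(padded, paddedMask)));
        // prints [3.0, 2.0, 1.0, 0.0, 0.0]

        // Alternating mask: present steps 1, 2, 3 are reversed around the padding.
        double[] mixed = {1, 9, 2, 9, 3};
        int[] mixedMask = {1, 0, 1, 0, 1};
        System.out.println(Arrays.toString(reverseMasked(mixed, mixedMask)));
        // prints [3.0, 9.0, 2.0, 9.0, 1.0]
    }
}
```

In the padded case the three present steps swap ends while the trailing padding stays at the end, which is the "first n time steps are reversed" situation from the description.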
Constructor and Description |
---|
ReverseTimeSeriesVertex(): Creates a new ReverseTimeSeriesVertex that doesn't pay attention to masks |
ReverseTimeSeriesVertex(String maskArrayInputName): Creates a new ReverseTimeSeriesVertex that uses the mask array of a given input |
Modifier and Type | Method and Description |
---|---|
ReverseTimeSeriesVertex | clone() |
boolean | equals(Object o) |
MemoryReport | getMemoryReport(InputType... inputTypes): This is a report of the estimated memory consumption for the given vertex |
InputType | getOutputType(int layerIndex, InputType... vertexInputs): Determine the type of output for this GraphVertex, given the specified inputs. |
int | hashCode() |
ReverseTimeSeriesVertex | instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, org.nd4j.linalg.api.buffer.DataType networkDatatype): Create a GraphVertex instance, for the given computation graph, given the configuration instance. |
int | maxVertexInputs() |
int | minVertexInputs() |
long | numParams(boolean backprop) |
String | toString() |
setDataType
public ReverseTimeSeriesVertex()
public ReverseTimeSeriesVertex(String maskArrayInputName)
maskArrayInputName - The name of the input that holds the mask.
public ReverseTimeSeriesVertex clone()
clone in class GraphVertex
public boolean equals(Object o)
equals in class GraphVertex
public int hashCode()
hashCode in class GraphVertex
public long numParams(boolean backprop)
numParams in class GraphVertex
public int minVertexInputs()
minVertexInputs in class GraphVertex
public int maxVertexInputs()
maxVertexInputs in class GraphVertex
public ReverseTimeSeriesVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, org.nd4j.linalg.api.buffer.DataType networkDatatype)
Create a GraphVertex instance, for the given computation graph, given the configuration instance.
instantiate in class GraphVertex
graph - The computation graph that this GraphVertex is to be part of
name - The name of the GraphVertex object
idx - The index of the GraphVertex
paramsView - A view of the full parameters array
initializeParams - If true: initialize the parameters. If false: make no change to the values in the paramsView array
public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException
Determine the type of output for this GraphVertex, given the specified inputs.
getOutputType in class GraphVertex
layerIndex - The index of the layer (if appropriate/necessary).
vertexInputs - The inputs to this vertex
InvalidInputTypeException - If the input type is invalid for this type of GraphVertex
public MemoryReport getMemoryReport(InputType... inputTypes)
This is a report of the estimated memory consumption for the given vertex.
getMemoryReport in class GraphVertex
inputTypes - Input types to the vertex. Memory consumption is often a function of the input type