public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE>

Maps graph proto types to SameDiff instances.

Type parameters:
GRAPH_TYPE - the proto type for the graph
NODE_TYPE - the proto type for the node
ATTR_TYPE - the proto type for the attribute
TENSOR_TYPE - the proto type for the tensor

Modifier and Type | Method | Description |
---|---|---|
boolean | alreadySeen(NODE_TYPE nodeType) | |
DataBuffer.Type | dataTypeForTensor(TENSOR_TYPE tensorType) | |
void | dumpBinaryProtoAsText(File inputFile, File outputFile) | Dump a binary proto file representation as a plain string into the target text file. |
void | dumpBinaryProtoAsText(InputStream inputFile, File outputFile) | Dump a binary proto file representation as a plain string into the target text file. |
INDArray | getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph) | |
Map<String,ATTR_TYPE> | getAttrMap(NODE_TYPE nodeType) | Get the attribute map for the given node. |
String | getAttrValueFromNode(NODE_TYPE nodeType, String key) | |
String | getInputFromNode(NODE_TYPE node, int index) | Get the input at position index for the given node. |
DifferentialFunction | getMappedOp(String name) | Get the mapped op for a given op name, relative to the type of node being mapped. |
String | getName(NODE_TYPE nodeType) | Get the name of the node. |
INDArray | getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph) | |
com.github.os72.protobuf351.Message.Builder | getNewGraphBuilder() | Returns a graph builder for initial definition and parsing. |
List<NODE_TYPE> | getNodeList(GRAPH_TYPE graphType) | |
NODE_TYPE | getNodeWithNameFromGraph(GRAPH_TYPE graph, String name) | Get the node with the given name from the graph. |
String | getOpType(NODE_TYPE nodeType) | |
int[] | getShape(NODE_TYPE nodeType) | |
int[] | getShapeFromAttr(ATTR_TYPE attr) | Get the shape of the attribute value. |
int[] | getShapeFromAttribute(ATTR_TYPE attrType) | |
int[] | getShapeFromTensor(TENSOR_TYPE tensorType) | Get the shape for the given tensor type. |
String | getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node) | Get the target mapping key (usually based on the node name) for the given function. |
boolean | hasShape(NODE_TYPE nodeType) | |
SameDiff | importGraph(File graphFile) | Import a graph as a SameDiff instance from the given file. |
SameDiff | importGraph(GRAPH_TYPE tfGraph) | Convert the given TF graph into a SameDiff instance. |
SameDiff | importGraph(InputStream graphFile) | Import a graph as a SameDiff instance from the given input stream. |
boolean | isOpIgnoreException(NODE_TYPE node) | Returns true if this node is a special case (maybe because of its name or other scenarios) that should override opsToIgnore() in certain circumstances. |
boolean | isPlaceHolder(TENSOR_TYPE nodeType) | Returns true if the given node is a placeholder type (think of a shape that is yet to be determined). |
boolean | isPlaceHolderNode(TENSOR_TYPE node) | Returns true if the given node is a placeholder. |
boolean | isVariableNode(NODE_TYPE nodeType) | |
void | mapNodeType(NODE_TYPE tfNode, ImportState<GRAPH_TYPE,TENSOR_TYPE> importState) | Map a node into the import state covering the SameDiff instance. |
void | mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappings) | |
void | mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction) | |
Map<String,NODE_TYPE> | nameIndexForGraph(GRAPH_TYPE graph) | |
Map<String,NODE_TYPE> | nodesByName(GRAPH_TYPE graph) | Get the nodes sorted by name from a given graph. |
int | numInputsFor(NODE_TYPE nodeType) | Get the number of inputs for a node. |
Set<String> | opsToIgnore() | Ops to ignore for mapping. |
Op.Type | opTypeForNode(NODE_TYPE nodeType) | Returns an op type for the given input node. |
GRAPH_TYPE | parseGraphFrom(byte[] inputStream) | Parse a graph from a byte array. |
GRAPH_TYPE | parseGraphFrom(InputStream inputStream) | Parse a graph from an input stream. |
boolean | shouldSkip(NODE_TYPE opType) | |
String | translateToSameDiffName(String name, NODE_TYPE node) | |
boolean | validTensorDataType(TENSOR_TYPE tensorType) | Whether the data type for the tensor is valid for creating an INDArray. |
Map<String,TENSOR_TYPE> | variablesForGraph(GRAPH_TYPE graphType) | Get the variables for the given graph. |

boolean isOpIgnoreException(NODE_TYPE node)
Returns true if this node is a special case (maybe because of its name or other scenarios) that should override opsToIgnore() in certain circumstances.
Parameters:
node - the node to check

Map<String,NODE_TYPE> nodesByName(GRAPH_TYPE graph)
Parameters:
graph - the graph to get the nodes for

String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node)
Parameters:
function - the function
node - the node to derive the target mapping from

void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappings)
Parameters: on, node, graph, sameDiff, propertyMappings

void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction)
Parameters: name, on, node, graph, sameDiff, propertyMappingsForFunction

NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph, String name)
Parameters:
graph - the graph to get the node from
name - the name of the node to get from the graph

boolean isPlaceHolderNode(TENSOR_TYPE node)
Parameters:
node - the node to check

void dumpBinaryProtoAsText(File inputFile, File outputFile)
Parameters: inputFile, outputFile

void dumpBinaryProtoAsText(InputStream inputFile, File outputFile)
Parameters: inputFile, outputFile

DifferentialFunction getMappedOp(String name)
Parameters:
name - the TensorFlow or ONNX name
See also: DifferentialFunctionClassHolder
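getMappedOp resolves a TensorFlow or ONNX op name to a DifferentialFunction. A minimal sketch of such a name-keyed registry follows. ToyOp and OpRegistrySketch are hypothetical stand-ins, and whether the real lookup returns fresh or shared instances is not stated here, so the sketch simply creates a fresh instance per call through a Supplier and returns null for unmapped names.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical stand-in for DifferentialFunction.
interface ToyOp {
    String opName();
}

// Hypothetical registry: framework op name -> factory for the mapped op.
final class OpRegistrySketch {
    private final Map<String, Supplier<ToyOp>> byImportName = new HashMap<>();

    // Register a factory under a framework-specific op name (e.g. a TensorFlow op type).
    void register(String importName, Supplier<ToyOp> factory) {
        byImportName.put(importName, factory);
    }

    // Analogue of getMappedOp(String name): a new op instance, or null if unmapped.
    ToyOp getMappedOp(String name) {
        Supplier<ToyOp> factory = byImportName.get(name);
        return factory == null ? null : factory.get();
    }
}
```

For example, after `registry.register("MatMul", () -> () -> "mmul")`, a lookup of `"MatMul"` yields an op whose `opName()` is `"mmul"`.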

Map<String,TENSOR_TYPE> variablesForGraph(GRAPH_TYPE graphType)
Parameters:
graphType - the graph to get the variables for

String translateToSameDiffName(String name, NODE_TYPE node)
Parameters: name, node

Map<String,NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph)
Parameters: graph
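nodesByName and nameIndexForGraph both return a Map<String,NODE_TYPE> keyed by node name. Assuming a graph exposes its nodes as a list (as getNodeList does) and each node carries a name (as getName returns), such an index can be built in one pass. IndexedNode and NameIndexSketch are hypothetical. A TreeMap keeps keys in name order, matching the "sorted by name" wording of nodesByName; rejecting duplicate names instead of overwriting them is a choice of this sketch, not documented behavior.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical node type: a name plus the names of its inputs.
final class IndexedNode {
    final String name;
    final List<String> inputs;

    IndexedNode(String name, List<String> inputs) {
        this.name = name;
        this.inputs = inputs;
    }
}

final class NameIndexSketch {
    // Build a name -> node index; TreeMap iterates keys in natural String order.
    static Map<String, IndexedNode> nodesByName(List<IndexedNode> nodeList) {
        Map<String, IndexedNode> index = new TreeMap<>();
        for (IndexedNode node : nodeList) {
            if (index.putIfAbsent(node.name, node) != null) {
                throw new IllegalArgumentException("duplicate node name: " + node.name);
            }
        }
        return index;
    }
}
```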

Op.Type opTypeForNode(NODE_TYPE nodeType)
Parameters:
nodeType - the node to use

com.github.os72.protobuf351.Message.Builder getNewGraphBuilder()

GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException
Parameters:
inputStream - the serialized graph bytes to load from
Throws:
IOException

GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException
Parameters:
inputStream - the input stream to load from
Throws:
IOException
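The two parseGraphFrom overloads differ only in their source, so an implementation can let the byte[] overload delegate to the InputStream one by wrapping the array in a ByteArrayInputStream, keeping the parsing logic in one place. In this sketch the "graph" is a hypothetical list of newline-separated node names rather than a protobuf message; a real mapper would more likely hand the stream to the generated proto class's parser.

```java
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

// Hypothetical parser: one node name per non-blank line.
final class ToyGraphParser {
    // Analogue of parseGraphFrom(InputStream); the caller owns and closes the stream.
    static List<String> parseGraphFrom(InputStream inputStream) throws IOException {
        List<String> nodeNames = new ArrayList<>();
        BufferedReader reader = new BufferedReader(
                new InputStreamReader(inputStream, StandardCharsets.UTF_8));
        String line;
        while ((line = reader.readLine()) != null) {
            String trimmed = line.trim();
            if (!trimmed.isEmpty()) {
                nodeNames.add(trimmed);
            }
        }
        return nodeNames;
    }

    // Analogue of parseGraphFrom(byte[]): delegate so both overloads share one parser.
    static List<String> parseGraphFrom(byte[] bytes) throws IOException {
        try (InputStream in = new ByteArrayInputStream(bytes)) {
            return parseGraphFrom(in);
        }
    }
}
```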

void mapNodeType(NODE_TYPE tfNode, ImportState<GRAPH_TYPE,TENSOR_TYPE> importState)
Map a node into the import state covering the SameDiff instance.
Parameters:
tfNode - the node to map
importState - the current import state

DataBuffer.Type dataTypeForTensor(TENSOR_TYPE tensorType)
Parameters: tensorType
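dataTypeForTensor and validTensorDataType suggest a two-step pattern: translate the framework's tensor dtype into a buffer type, and only build an INDArray when that translation succeeds. The sketch uses hypothetical enums (SourceDType, TargetType) in place of the proto dtype and DataBuffer.Type, and which source types count as supported is an assumption made purely for illustration.

```java
import java.util.EnumMap;
import java.util.Map;

// Hypothetical source-framework dtypes (stand-in for a proto DataType enum).
enum SourceDType { DT_FLOAT, DT_DOUBLE, DT_INT32, DT_STRING }

// Hypothetical target buffer types (stand-in for DataBuffer.Type).
enum TargetType { FLOAT, DOUBLE, INT, UNKNOWN }

final class DTypeSketch {
    private static final Map<SourceDType, TargetType> MAPPING = new EnumMap<>(SourceDType.class);

    static {
        MAPPING.put(SourceDType.DT_FLOAT, TargetType.FLOAT);
        MAPPING.put(SourceDType.DT_DOUBLE, TargetType.DOUBLE);
        MAPPING.put(SourceDType.DT_INT32, TargetType.INT);
        // DT_STRING is deliberately left unmapped: no numeric buffer type in this sketch.
    }

    // Analogue of dataTypeForTensor: UNKNOWN when there is no numeric equivalent.
    static TargetType dataTypeForTensor(SourceDType dtype) {
        return MAPPING.getOrDefault(dtype, TargetType.UNKNOWN);
    }

    // Analogue of validTensorDataType: only mapped types can back an array.
    static boolean validTensorDataType(SourceDType dtype) {
        return dataTypeForTensor(dtype) != TargetType.UNKNOWN;
    }
}
```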

String getAttrValueFromNode(NODE_TYPE nodeType, String key)
Parameters: nodeType, key

int[] getShapeFromAttribute(ATTR_TYPE attrType)
Parameters: attrType

boolean isPlaceHolder(TENSOR_TYPE nodeType)
Parameters: nodeType
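isPlaceHolder describes a placeholder as a tensor whose shape is yet to be determined. TensorFlow marks an unknown dimension with -1, and the sketch below assumes that encoding (this page does not confirm it for GraphMapper) and also treats a missing shape as undetermined. PlaceholderShapeSketch is a hypothetical helper, not part of the interface.

```java
final class PlaceholderShapeSketch {
    // Assumed convention: -1 marks a dimension whose size is not yet known.
    static final int UNKNOWN_DIM = -1;

    // True if the shape is absent or has at least one undetermined dimension.
    static boolean isUndetermined(int[] shape) {
        if (shape == null) {
            return true;
        }
        for (int dim : shape) {
            if (dim == UNKNOWN_DIM) {
                return true;
            }
        }
        return false;
    }
}
```

A batch dimension left open, as in a shape of `[-1, 784]`, is the typical case this check is meant to catch.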

INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph)
Parameters: tensorName, tensorType, graph
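getNDArrayFromTensor and getArrayFrom turn a serialized tensor (a flat run of values plus a shape) into an INDArray. The core of that step is index arithmetic: in a row-major (C order) layout, the stride of each dimension is the product of the sizes of all later dimensions. The sketch shows only that arithmetic in plain Java, assumes C order, and does not use the ND4J API; RowMajorSketch is a hypothetical name.

```java
final class RowMajorSketch {
    // Strides for a C-order (row-major) layout: the last dimension varies fastest.
    static int[] strides(int[] shape) {
        int[] strides = new int[shape.length];
        int running = 1;
        for (int i = shape.length - 1; i >= 0; i--) {
            strides[i] = running;
            running *= shape[i];
        }
        return strides;
    }

    // Flat offset of a multi-dimensional index into the serialized value array.
    static int offset(int[] shape, int... index) {
        if (index.length != shape.length) {
            throw new IllegalArgumentException("rank mismatch");
        }
        int[] strides = strides(shape);
        int flat = 0;
        for (int i = 0; i < index.length; i++) {
            if (index[i] < 0 || index[i] >= shape[i]) {
                throw new IndexOutOfBoundsException("index " + index[i] + " in dim " + i);
            }
            flat += index[i] * strides[i];
        }
        return flat;
    }
}
```

With shape `[2, 3]` and values `[1, 2, 3, 4, 5, 6]`, the strides are `[3, 1]`, so element `(1, 0)` sits at offset 3, which holds the value 4.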

int[] getShapeFromTensor(TENSOR_TYPE tensorType)
Parameters: tensorType

String getInputFromNode(NODE_TYPE node, int index)
Parameters:
node - the node
index - the index

int numInputsFor(NODE_TYPE nodeType)
Parameters:
nodeType - the node to get the number of inputs for

boolean validTensorDataType(TENSOR_TYPE tensorType)
Whether the data type for the tensor is valid for creating an INDArray.
Parameters:
tensorType - the tensor proto to test

int[] getShapeFromAttr(ATTR_TYPE attr)
Parameters:
attr - the attribute value

Map<String,ATTR_TYPE> getAttrMap(NODE_TYPE nodeType)
Parameters:
nodeType - the node

String getName(NODE_TYPE nodeType)
Parameters:
nodeType - the node to get the name for

boolean alreadySeen(NODE_TYPE nodeType)
Parameters: nodeType
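numInputsFor and getInputFromNode expose a node's inputs by position, and alreadySeen lets an importer avoid processing a node twice. Together they are enough for a dependency-first (post-order) walk in which each node is emitted after its inputs. The sketch assumes inputs are plain node names and ignores TensorFlow's `name:port` and `^name` input forms; DepNode and ImportOrderSketch are hypothetical, and this is one possible ordering strategy, not necessarily the one GraphMapper implementations use.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical node: a name plus input node names, addressed by position.
final class DepNode {
    final String name;
    final List<String> inputs;

    DepNode(String name, List<String> inputs) {
        this.name = name;
        this.inputs = inputs;
    }

    int numInputs() { return inputs.size(); }                // like numInputsFor(node)
    String getInput(int index) { return inputs.get(index); } // like getInputFromNode(node, index)
}

final class ImportOrderSketch {
    // Order nodes so each appears after all of its inputs (assumes an acyclic graph).
    static List<String> importOrder(Map<String, DepNode> byName) {
        List<String> order = new ArrayList<>();
        Set<String> seen = new HashSet<>(); // plays the role of alreadySeen(node)
        for (String name : byName.keySet()) {
            visit(name, byName, seen, order);
        }
        return order;
    }

    private static void visit(String name, Map<String, DepNode> byName,
                              Set<String> seen, List<String> order) {
        if (!seen.add(name)) {
            return; // already seen: nothing to do
        }
        DepNode node = byName.get(name);
        if (node == null) {
            return; // input refers to a node outside this map: ignore it
        }
        for (int i = 0; i < node.numInputs(); i++) {
            visit(node.getInput(i), byName, seen, order);
        }
        order.add(name); // post-order: emitted after all inputs
    }
}
```

The recursion is kept for readability; very deep graphs would call for an explicit stack instead.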

boolean isVariableNode(NODE_TYPE nodeType)
Parameters: nodeType

boolean shouldSkip(NODE_TYPE opType)
Parameters: opType
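shouldSkip, opsToIgnore and isOpIgnoreException fit together naturally: an op type listed in opsToIgnore() is skipped unless isOpIgnoreException marks that particular node as a special case. The exact rule behind shouldSkip is not spelled out here, so the combination below is an assumption. SkipNode, SkipRuleSketch and the choice to key exceptions by node name (the description mentions names as one possible reason) are hypothetical.

```java
import java.util.Set;

// Hypothetical node with an op type and a name.
final class SkipNode {
    final String opType;
    final String name;

    SkipNode(String opType, String name) {
        this.opType = opType;
        this.name = name;
    }
}

final class SkipRuleSketch {
    private final Set<String> opsToIgnore;
    private final Set<String> exceptionNames; // assumed: exceptions keyed by node name

    SkipRuleSketch(Set<String> opsToIgnore, Set<String> exceptionNames) {
        this.opsToIgnore = opsToIgnore;
        this.exceptionNames = exceptionNames;
    }

    // Like isOpIgnoreException: a special case that overrides opsToIgnore().
    boolean isOpIgnoreException(SkipNode node) {
        return exceptionNames.contains(node.name);
    }

    // Assumed rule: skip ignored op types unless this node is an exception.
    boolean shouldSkip(SkipNode node) {
        return opsToIgnore.contains(node.opType) && !isOpIgnoreException(node);
    }
}
```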

boolean hasShape(NODE_TYPE nodeType)
Parameters: nodeType

int[] getShape(NODE_TYPE nodeType)
Parameters: nodeType

INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph)
Parameters: nodeType, graph

List<NODE_TYPE> getNodeList(GRAPH_TYPE graphType)
Parameters: graphType

SameDiff importGraph(InputStream graphFile)
Parameters: graphFile

SameDiff importGraph(File graphFile)
Parameters: graphFile

SameDiff importGraph(GRAPH_TYPE tfGraph)
Parameters: tfGraph
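Read together, the methods documented here suggest the outline of importGraph(GRAPH_TYPE): enumerate the nodes, drop those that shouldSkip rejects, and map the rest into an import state. The sketch is a simplified, hypothetical driver with strings standing in for NODE_TYPE and a ToyImportState standing in for ImportState; it is not the actual importGraph implementation, and a real one would also need to handle variables (variablesForGraph) and placeholders (isPlaceHolder), which this sketch omits.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical stand-in for ImportState: records the names of mapped nodes.
final class ToyImportState {
    final List<String> mappedNodes = new ArrayList<>();
}

final class ImportDriverSketch {
    // Simplified driver: walk the node list in order, skip, and map the rest.
    static ToyImportState importGraph(List<String> nodeList, Predicate<String> shouldSkip) {
        ToyImportState state = new ToyImportState();
        for (String node : nodeList) {
            if (shouldSkip.test(node)) {
                continue;
            }
            mapNodeType(node, state);
        }
        return state;
    }

    // Analogue of mapNodeType(node, importState): here it only records the node.
    static void mapNodeType(String node, ToyImportState state) {
        state.mappedNodes.add(node);
    }
}
```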

Copyright © 2018. All rights reserved.