public class TFGraphMapper extends BaseGraphMapper<GraphDef,NodeDef,AttrValue,NodeDef>
Modifier and Type | Field and Description |
---|---|
static String |
SHAPE_KEY |
static String |
VALUE_ATTR_KEY |
Modifier and Type | Method and Description |
---|---|
boolean |
alreadySeen(NodeDef nodeDef) |
DataBuffer.Type |
dataTypeForTensor(NodeDef tensorProto) |
void |
dumpBinaryProtoAsText(File inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file
|
void |
dumpBinaryProtoAsText(InputStream inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file
|
INDArray |
getArrayFrom(NodeDef nodeDef,
GraphDef graph) |
Map<String,AttrValue> |
getAttrMap(NodeDef nodeDef)
Get the attribute
map for the given node
|
String |
getAttrValueFromNode(NodeDef nodeDef,
String key) |
String |
getInputFromNode(NodeDef node,
int index)
Get the input node for the given node
|
static TFGraphMapper |
getInstance()
Singleton.
|
DifferentialFunction |
getMappedOp(String name)
Get the mapped op for the given op name,
relative to the type of node being mapped.
|
String |
getName(NodeDef nodeDef)
Get the name of the node
|
INDArray |
getNDArrayFromTensor(String tensorName,
NodeDef node,
GraphDef graph) |
com.github.os72.protobuf351.Message.Builder |
getNewGraphBuilder()
Returns a graph builder for initial definition and parsing.
|
List<NodeDef> |
getNodeList(GraphDef graphDef) |
String |
getNodeName(String name)
Map a TensorFlow node name
to the SameDiff equivalent
for import
|
NodeDef |
getNodeWithNameFromGraph(GraphDef graph,
String name)
Get the node from the graph
|
String |
getOpType(NodeDef nodeDef) |
long[] |
getShape(NodeDef nodeDef) |
long[] |
getShapeFromAttr(AttrValue attr)
Get the shape of the attribute value
|
long[] |
getShapeFromAttribute(AttrValue attrValue) |
long[] |
getShapeFromTensor(NodeDef tensorProto)
Get the shape for the given tensor type
|
String |
getTargetMappingForOp(DifferentialFunction function,
NodeDef node)
Get the target mapping key (usually based on the node name)
for the given function
|
boolean |
hasShape(NodeDef nodeDef) |
protected void |
importCondition(String conditionName,
NodeDef tfNode,
ImportState<GraphDef,NodeDef> importState) |
void |
initFunctionFromProperties(DifferentialFunction on,
Map<String,AttrValue> attributesForNode,
NodeDef node,
GraphDef graph)
|
void |
initFunctionFromProperties(String mappedTfName,
DifferentialFunction on,
Map<String,AttrValue> attributesForNode,
NodeDef node,
GraphDef graph)
Init a function's attributes
|
boolean |
isOpIgnoreException(NodeDef node)
Returns true if this node is a special case
(maybe because of name or other scenarios)
that should override
GraphMapper.opsToIgnore()
in certain circumstances |
boolean |
isPlaceHolder(NodeDef nodeDef)
Returns true if the given node is a place holder type
(think of a yet-to-be-determined shape)
|
boolean |
isPlaceHolderNode(NodeDef node)
Returns true if the given node is a place holder
|
boolean |
isVariableNode(NodeDef nodeDef) |
void |
mapNodeType(NodeDef tfNode,
ImportState<GraphDef,NodeDef> importState)
Map a node into the import state covering
the SameDiff instance |
void |
mapProperty(String name,
DifferentialFunction on,
NodeDef node,
GraphDef graph,
SameDiff sameDiff,
Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction) |
INDArray |
mapTensorProto(TensorProto tfTensor) |
IfImportState |
nodesForIf(NodeDef from,
GraphDef graph)
Returns the nodes for an if statement
|
int |
numInputsFor(NodeDef nodeDef)
Get the number of inputs for a node.
|
Set<String> |
opsToIgnore()
Ops to ignore for mapping
|
GraphDef |
parseGraphFrom(byte[] inputStream)
Parse a graph from a serialized byte array
|
GraphDef |
parseGraphFrom(InputStream inputStream)
Parse a graph from an input stream
|
boolean |
shouldSkip(NodeDef opType) |
String |
translateToSameDiffName(String name,
NodeDef node) |
Map<String,NodeDef> |
variablesForGraph(GraphDef graphDef)
Get the variables for the given graph
|
importGraph, importGraph, importGraph, importGraph, mapProperties, nameIndexForGraph, nodesByName, opTypeForNode, readGraph, validTensorDataType
public static final String VALUE_ATTR_KEY
public static final String SHAPE_KEY
public static TFGraphMapper getInstance()
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile)
GraphMapper
public boolean isOpIgnoreException(NodeDef node)
GraphMapper
GraphMapper.opsToIgnore()
in certain circumstances
node
- the node to check
public String getTargetMappingForOp(DifferentialFunction function, NodeDef node)
GraphMapper
function
- the function
node
- the node to derive the target mapping from
public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name)
GraphMapper
graph
- the graph to get the node from
name
- the name of the node to get from the graph
public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction)
public boolean isPlaceHolderNode(NodeDef node)
node
- the node to check
public void dumpBinaryProtoAsText(File inputFile, File outputFile)
public long[] getShapeFromAttr(AttrValue attr)
GraphMapper
attr
- the attribute value
public Map<String,AttrValue> getAttrMap(NodeDef nodeDef)
GraphMapper
nodeDef
- the node
public String getName(NodeDef nodeDef)
GraphMapper
nodeDef
- the node to get the name for
public boolean alreadySeen(NodeDef nodeDef)
public boolean isVariableNode(NodeDef nodeDef)
public boolean shouldSkip(NodeDef opType)
public boolean hasShape(NodeDef nodeDef)
public long[] getShape(NodeDef nodeDef)
public DifferentialFunction getMappedOp(String name)
GraphMapper
name
- the TensorFlow or ONNX name
public String getNodeName(String name)
name
- the name to change
public Map<String,NodeDef> variablesForGraph(GraphDef graphDef)
GraphMapper
graphDef
- the graph to get the variables for
public com.github.os72.protobuf351.Message.Builder getNewGraphBuilder()
GraphMapper
public GraphDef parseGraphFrom(byte[] inputStream) throws IOException
GraphMapper
inputStream
- the serialized graph bytes to load from
IOException
public GraphDef parseGraphFrom(InputStream inputStream) throws IOException
GraphMapper
inputStream
- the input stream to load from
IOException
protected void importCondition(String conditionName, NodeDef tfNode, ImportState<GraphDef,NodeDef> importState)
public void mapNodeType(NodeDef tfNode, ImportState<GraphDef,NodeDef> importState)
GraphMapper
SameDiff
instance
tfNode
- the node to map
importState
- the current import state
public void initFunctionFromProperties(DifferentialFunction on, Map<String,AttrValue> attributesForNode, NodeDef node, GraphDef graph)
Delegates to initFunctionFromProperties(String, DifferentialFunction, Map, NodeDef, GraphDef),
using DifferentialFunction.tensorflowName() as the TensorFlow name
on
- the function to use init on
attributesForNode
- the attributes for the node
node
- the node
graph
- the graph
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String,AttrValue> attributesForNode, NodeDef node, GraphDef graph)
mappedTfName
- the TensorFlow name to pick (sometimes ops have multiple names)
on
- the function to map
attributesForNode
- the attributes for the node
node
- the node
graph
- the graph
public DataBuffer.Type dataTypeForTensor(NodeDef tensorProto)
public long[] getShapeFromAttribute(AttrValue attrValue)
public boolean isPlaceHolder(NodeDef nodeDef)
GraphMapper
public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph)
public INDArray mapTensorProto(TensorProto tfTensor)
public long[] getShapeFromTensor(NodeDef tensorProto)
GraphMapper
public Set<String> opsToIgnore()
GraphMapper
public String getInputFromNode(NodeDef node, int index)
GraphMapper
node
- the node
index
- the index
public int numInputsFor(NodeDef nodeDef)
GraphMapper
nodeDef
- the node to get the number of inputs for
public IfImportState nodesForIf(NodeDef from, GraphDef graph)
from
- the starting node (a merge node that represents a conditional)
graph
- the graph to search
Copyright © 2018. All rights reserved.