public class OnnxGraphMapper extends BaseGraphMapper<OnnxProto3.GraphProto,OnnxProto3.NodeProto,OnnxProto3.AttributeProto,OnnxProto3.TypeProto.Tensor>
SameDiff
instances.
Constructor and Description |
---|
OnnxGraphMapper() |
Modifier and Type | Method and Description |
---|---|
protected void |
addDummyTensor(String name,
Map<String,OnnxProto3.TypeProto.Tensor> to) |
boolean |
alreadySeen(OnnxProto3.NodeProto nodeProto) |
DataType |
dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto,
int outputNum) |
void |
dumpBinaryProtoAsText(File inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file
|
void |
dumpBinaryProtoAsText(InputStream inputFile,
File outputFile)
Dump a binary proto file representation as a
plain string into the target text file
|
INDArray |
getArrayFrom(OnnxProto3.NodeProto nodeProto,
OnnxProto3.GraphProto graph) |
Map<String,OnnxProto3.AttributeProto> |
getAttrMap(OnnxProto3.NodeProto nodeProto)
Get the attribute
map for the given node
|
String |
getAttrValueFromNode(OnnxProto3.NodeProto nodeProto,
String key) |
List<String> |
getControlDependencies(OnnxProto3.NodeProto node)
Get the list of control dependencies for the current node (or null if none exist)
|
String |
getInputFromNode(OnnxProto3.NodeProto node,
int index)
Get the input node for the given node
|
static OnnxGraphMapper |
getInstance() |
DifferentialFunction |
getMappedOp(String name)
Get the mapped op name
for a given op
relative to the type of node being mapped.
|
String |
getName(OnnxProto3.NodeProto nodeProto)
Get the name of the node
|
INDArray |
getNDArrayFromTensor(String tensorName,
OnnxProto3.TypeProto.Tensor tensorProto,
OnnxProto3.GraphProto graph) |
com.github.os72.protobuf351.Message.Builder |
getNewGraphBuilder()
Returns a graph builder for initial definition and parsing.
|
List<OnnxProto3.NodeProto> |
getNodeList(OnnxProto3.GraphProto graphProto) |
OnnxProto3.NodeProto |
getNodeWithNameFromGraph(OnnxProto3.GraphProto graph,
String name)
Get the node from the graph
|
String |
getOpType(OnnxProto3.NodeProto nodeProto) |
long[] |
getShape(OnnxProto3.NodeProto nodeProto) |
long[] |
getShapeFromAttr(OnnxProto3.AttributeProto attr)
Get the shape of the attribute value
|
long[] |
getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto) |
long[] |
getShapeFromTensor(OnnxProto3.TensorProto tensorProto)
Get the shape from a tensor proto.
|
long[] |
getShapeFromTensor(OnnxProto3.TypeProto.Tensor tensorProto)
Get the shape for the given tensor type
|
String |
getTargetMappingForOp(DifferentialFunction function,
OnnxProto3.NodeProto node)
Get the target mapping key (usually based on the node name)
for the given function
|
boolean |
hasShape(OnnxProto3.NodeProto nodeProto) |
void |
initFunctionFromProperties(String mappedTfName,
DifferentialFunction on,
Map<String,OnnxProto3.AttributeProto> attributesForNode,
OnnxProto3.NodeProto node,
OnnxProto3.GraphProto graph)
Init a function's attributes
|
boolean |
isConstant(OnnxProto3.TypeProto.Tensor nodeType)
Returns true if the given node is a constant
|
boolean |
isOpIgnoreException(OnnxProto3.NodeProto node)
Returns true if this node is a special case
(maybe because of name or other scenarios)
that should override
GraphMapper.opsToIgnore()
in certain circumstances |
boolean |
isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType)
Returns true if the given node is a place holder type
(think of a yet-to-be-determined shape)
|
boolean |
isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node)
Returns true if the given node is a place holder
|
boolean |
isStringType(OnnxProto3.TypeProto.Tensor tensor) |
boolean |
isVariableNode(OnnxProto3.NodeProto nodeProto) |
void |
mapNodeType(OnnxProto3.NodeProto tfNode,
ImportState<OnnxProto3.GraphProto,OnnxProto3.TypeProto.Tensor> importState,
OpImportOverride<OnnxProto3.GraphProto,OnnxProto3.NodeProto,OnnxProto3.AttributeProto> opImportOverride,
OpImportFilter<OnnxProto3.GraphProto,OnnxProto3.NodeProto,OnnxProto3.AttributeProto> opFilter)
Map a node into the import state covering the
SameDiff instance |
void |
mapProperty(String name,
DifferentialFunction on,
OnnxProto3.NodeProto node,
OnnxProto3.GraphProto graph,
SameDiff sameDiff,
Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction) |
INDArray |
mapTensorProto(OnnxProto3.TensorProto tensor) |
DataType |
nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType)
Convert an onnx type to the proper nd4j type
|
int |
numInputsFor(OnnxProto3.NodeProto nodeProto)
Get the number of inputs for a node.
|
Set<String> |
opsToIgnore()
Ops to ignore for mapping
|
OnnxProto3.GraphProto |
parseGraphFrom(byte[] inputStream)
Parse a graph from a byte array
|
OnnxProto3.GraphProto |
parseGraphFrom(InputStream inputStream)
Parse a graph from an input stream
|
boolean |
shouldSkip(OnnxProto3.NodeProto opType) |
String |
translateToSameDiffName(String name,
OnnxProto3.NodeProto node) |
Map<String,OnnxProto3.TypeProto.Tensor> |
variablesForGraph(OnnxProto3.GraphProto graphProto)
Get the variables for the given graph
|
importGraph, importGraph, importGraph, importGraph, importGraph, importGraph, initOutputVariables, mapProperties, nameIndexForGraph, nodesByName, opTypeForNode, readGraph, validateGraphStructure, validTensorDataType
public static OnnxGraphMapper getInstance()
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile)
GraphMapper
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String,OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph)
mappedTfName - the onnx name to pick (sometimes ops have multiple names)
on - the function to map
attributesForNode - the attributes for the node
node - the node
graph - the graph
public boolean isOpIgnoreException(OnnxProto3.NodeProto node)
GraphMapper
GraphMapper.opsToIgnore()
in certain circumstances
node - the node to check
public String getTargetMappingForOp(DifferentialFunction function, OnnxProto3.NodeProto node)
GraphMapper
function - the function
node - the node to derive the target mapping from
public void mapProperty(String name, DifferentialFunction on, OnnxProto3.NodeProto node, OnnxProto3.GraphProto graph, SameDiff sameDiff, Map<String,Map<String,PropertyMapping>> propertyMappingsForFunction)
public OnnxProto3.NodeProto getNodeWithNameFromGraph(OnnxProto3.GraphProto graph, String name)
GraphMapper
graph - the graph to get the node from
name - the name of the node to get from the graph
public boolean isPlaceHolderNode(OnnxProto3.TypeProto.Tensor node)
GraphMapper
node - the node to check
public List<String> getControlDependencies(OnnxProto3.NodeProto node)
GraphMapper
node - Node to get the control dependencies (if any) for
public void dumpBinaryProtoAsText(File inputFile, File outputFile)
GraphMapper
public DifferentialFunction getMappedOp(String name)
GraphMapper
name - the tensorflow or onnx name
public Map<String,OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto)
GraphMapper
graphProto - the graph to get the variables for
public String translateToSameDiffName(String name, OnnxProto3.NodeProto node)
protected void addDummyTensor(String name, Map<String,OnnxProto3.TypeProto.Tensor> to)
public com.github.os72.protobuf351.Message.Builder getNewGraphBuilder()
GraphMapper
public OnnxProto3.GraphProto parseGraphFrom(byte[] inputStream) throws IOException
GraphMapper
inputStream - the serialized graph bytes to load from
Throws:
IOException
public OnnxProto3.GraphProto parseGraphFrom(InputStream inputStream) throws IOException
GraphMapper
inputStream - the input stream to load from
Throws:
IOException
public void mapNodeType(OnnxProto3.NodeProto tfNode, ImportState<OnnxProto3.GraphProto,OnnxProto3.TypeProto.Tensor> importState, OpImportOverride<OnnxProto3.GraphProto,OnnxProto3.NodeProto,OnnxProto3.AttributeProto> opImportOverride, OpImportFilter<OnnxProto3.GraphProto,OnnxProto3.NodeProto,OnnxProto3.AttributeProto> opFilter)
GraphMapper
SameDiff
instance
tfNode - the node to map
importState - the current import state
opFilter - Optional filter for skipping operations
public DataType dataTypeForTensor(OnnxProto3.TypeProto.Tensor tensorProto, int outputNum)
public boolean isStringType(OnnxProto3.TypeProto.Tensor tensor)
public DataType nd4jTypeFromOnnxType(OnnxProto3.TensorProto.DataType dataType)
dataType - the data type to convert
public String getAttrValueFromNode(OnnxProto3.NodeProto nodeProto, String key)
public long[] getShapeFromAttribute(OnnxProto3.AttributeProto attributeProto)
public boolean isPlaceHolder(OnnxProto3.TypeProto.Tensor nodeType)
GraphMapper
public boolean isConstant(OnnxProto3.TypeProto.Tensor nodeType)
GraphMapper
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph)
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor)
public long[] getShapeFromTensor(OnnxProto3.TypeProto.Tensor tensorProto)
GraphMapper
public long[] getShapeFromTensor(OnnxProto3.TensorProto tensorProto)
getShapeFromTensor(OnnxProto3.TensorProto)
tensorProto - the tensor to get the shape from
public Set<String> opsToIgnore()
GraphMapper
public String getInputFromNode(OnnxProto3.NodeProto node, int index)
GraphMapper
node - the node
index - the index
public int numInputsFor(OnnxProto3.NodeProto nodeProto)
GraphMapper
nodeProto - the node to get the number of inputs for
public long[] getShapeFromAttr(OnnxProto3.AttributeProto attr)
GraphMapper
attr - the attribute value
public Map<String,OnnxProto3.AttributeProto> getAttrMap(OnnxProto3.NodeProto nodeProto)
GraphMapper
nodeProto - the node
public String getName(OnnxProto3.NodeProto nodeProto)
GraphMapper
nodeProto - the node to get the name for
public boolean alreadySeen(OnnxProto3.NodeProto nodeProto)
public boolean isVariableNode(OnnxProto3.NodeProto nodeProto)
public boolean shouldSkip(OnnxProto3.NodeProto opType)
public boolean hasShape(OnnxProto3.NodeProto nodeProto)
public long[] getShape(OnnxProto3.NodeProto nodeProto)
public INDArray getArrayFrom(OnnxProto3.NodeProto nodeProto, OnnxProto3.GraphProto graph)
public String getOpType(OnnxProto3.NodeProto nodeProto)
public List<OnnxProto3.NodeProto> getNodeList(OnnxProto3.GraphProto graphProto)
Copyright © 2019. All rights reserved.