为什么要说 ONNX?ONNX 又是个什么东西?经常要部署神经网络应用的童鞋们可能对 ONNX 会比较熟悉:我们可能会在某一任务中将 Pytorch 或者 TensorFlow 模型转化为 ONNX 模型(ONNX 模型一般用于中间部署阶段),然后再把转化后的 ONNX 模型进一步转化为不同部署框架所需要的格式。


  • Pytorch -> ONNX -> TensorRT
  • Pytorch -> ONNX -> TVM
  • TensorFlow -> ONNX -> ncnn

等等。ONNX 起到的相当于一个翻译的作用,这也是为什么 ONNX 叫做 Open Neural Network Exchange(开放神经网络交换)。



比如 Pytorch 保存的权重文件就叫做 model.pt ,这个我们应该很熟悉吧:它是二进制的模型权重文件,我们可以读取这个文件,相当于预加载了权重信息。

那ONNX呢,利用Pytorch我们可以将 model.pt 转化为 model.onnx 格式的权重,在这里onnx充当一个后缀名称, model.onnx 就代表ONNX格式的权重文件,这个权重文件不仅包含了权重值,也包含了神经网络的网络流动信息以及每一层网络的输入输出信息和一些其他的辅助信息。

简单拿 netron 这个工具来可视化(读取 ONNX 文件)一下:


如图,ONNX中的一些信息都被可视化展示了出来,例如文件格式 ONNX v3 ,该文件的导出方 pytorch 0.4 等等,这些信息都保存在ONNX格式的文件中。





关于 Protobuf(Google Protocol Buffer),有一篇文章比较好地对此进行了介绍: https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/index.html ,本文就不进行详细的介绍了。


ONNX中最核心的就是 onnx.proto 这个文件,这个文件中定义了ONNX这个数据协议的规则和一些其他信息。

下面是 onnx.proto 文件,这个文件可以帮助我们了解ONNX到底包含了一些什么样的信息。


// Copyright (c) Facebook Inc. and Microsoft Corporation.
// Licensed under the MIT license.

syntax = "proto2";

// ... 省略了一部分

// Nodes
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
// For example, it can be a node of type "Conv" that takes in an image, a filter 
// tensor and a bias tensor, and produces the convolved output.

// Node就是神经网络中的一个个操作结点,例如conv、reshape、relu等之类的操作 

message NodeProto {
  repeated string input = 1;    // namespace Value
  repeated string output = 2;   // namespace Value

  // An optional identifier for this node in a graph.
  // This field MAY be absent in this version of the IR.
  optional string name = 3;     // namespace Node

  // The symbolic identifier of the Operator to execute.
  optional string op_type = 4;  // namespace Operator
  // The domain of the OperatorSet that specifies the operator named by op_type.
  optional string domain = 7;   // namespace Domain

  // Additional named attributes.
  // Attributes carry per-node settings; for a Conv node, for example,
  // the kernel size and the stride.
  repeated AttributeProto attribute = 5;

  // A human-readable documentation for this node. Markdown is allowed.
  optional string doc_string = 6;
}

// Models
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
// The semantics of the model are described by the associated GraphProto.
// Models作为最大的单位,包含了Graph以及一些其他版本信息
message ModelProto {
  // The version of the IR this model targets. See Version enum above.
  // This field MUST be present.
  optional int64 ir_version = 1;

  // The OperatorSets this model relies on.
  // All ModelProtos MUST have at least one entry that
  // specifies which version of the ONNX OperatorSet is
  // being imported.
  // All nodes in the ModelProto's graph will bind against the operator
  // with the same-domain/same-op_type operator with the HIGHEST version
  // in the referenced operator sets.
  repeated OperatorSetIdProto opset_import = 8;

  // The name of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  optional string producer_name = 2;

  // The version of the framework or tool used to generate this model.
  // This field SHOULD be present to indicate which implementation/tool/framework
  // emitted the model.
  optional string producer_version = 3;

  // Domain name of the model.
  // We use reverse domain names as name space indicators. For example:
  // com.facebook.fair or com.microsoft.cognitiveservices
  // Together with model_version and GraphProto.name, this forms the unique identity of
  // the graph.
  optional string domain = 4;

  // The version of the graph encoded. See Version enum below.
  optional int64 model_version = 5;

  // A human-readable documentation for this model. Markdown is allowed.
  optional string doc_string = 6;

  // The parameterized graph that is evaluated to execute the model.
  // Key part: the graph is the directed acyclic graph that holds the
  // network information.
  optional GraphProto graph = 7;

  // Named metadata values; keys should be distinct.
  repeated StringStringEntryProto metadata_props = 14;
}

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
  // Key of the entry.
  optional string key = 1;
  // Value stored under the key.
  optional string value = 2;
}

// Graphs
// A graph defines the computational logic of a model and is comprised of a parameterized 
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
// Graphs是最重要的部分,里面包含了模型的构造和模型的权重等一切我们需要的信息
message GraphProto {
  // The nodes in the graph, sorted topologically.
  // Each node represents one operation of the model, e.g. Conv.
  repeated NodeProto node = 1;

  // The name of the graph.
  optional string name = 2;   // namespace Graph

  // A list of named tensor values, used to specify constant inputs of the graph.
  // Each TensorProto entry must have a distinct name (within the list) that
  // also appears in the input list.
  // initializer stores all parameters of the model, i.e. what is usually
  // called the model weights.
  repeated TensorProto initializer = 5;

  // A human-readable documentation for this graph. Markdown is allowed.
  optional string doc_string = 10;

  // The inputs and outputs of the graph.
  // Article author's note on `input`: all inputs of the model, including the
  // image fed in at the start and the input information of each node.
  repeated ValueInfoProto input = 11;
  repeated ValueInfoProto output = 12;

  // Information for the values in the graph. The ValueInfoProto.name's
  // must be distinct. It is optional for a value to appear in value_info list.
  repeated ValueInfoProto value_info = 13;

  // DO NOT USE the following fields, they were deprecated from earlier versions.
  // repeated string input = 3;
  // repeated string output = 4;
  // optional int64 ir_version = 6;
  // optional int64 producer_version = 7;
  // optional string producer_tag = 8;
  // optional string domain = 9;
}

// Tensors
// A serialized tensor value.
message TensorProto {
  // Element types a tensor can hold.
  enum DataType {
    UNDEFINED = 0;
    // Basic types.
    FLOAT = 1;   // float
    UINT8 = 2;   // uint8_t
    INT8 = 3;    // int8_t
    UINT16 = 4;  // uint16_t
    INT16 = 5;   // int16_t
    INT32 = 6;   // int32_t
    INT64 = 7;   // int64_t
    STRING = 8;  // string
    BOOL = 9;    // bool

    // IEEE754 half-precision floating-point format (16 bits wide).
    // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
    FLOAT16 = 10;

    DOUBLE = 11;
    UINT32 = 12;
    UINT64 = 13;
    COMPLEX64 = 14;     // complex with float32 real and imaginary components
    COMPLEX128 = 15;    // complex with float64 real and imaginary components

    // Non-IEEE floating-point format based on IEEE754 single-precision
    // floating-point number truncated to 16 bits.
    // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
    BFLOAT16 = 16;

    // Future extensions go here.
  }

  // The shape of the tensor.
  repeated int64 dims = 1;

  // The data type of the tensor.
  // This field MUST have a valid TensorProto.DataType value
  optional int32 data_type = 2;

  // NOTE(review): this is an excerpt; the dump shown later in the article
  // also sets `name` and `raw_data` on an initializer, so further fields
  // appear to be omitted here.
}


// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
  // A single dimension: either a concrete size or a symbolic name.
  message Dimension {
    oneof value {
      int64 dim_value = 1;
      string dim_param = 2;   // namespace Shape
    }
    // Standard denotation can optionally be used to denote tensor
    // dimensions with standard semantic descriptions to ensure
    // that operations are applied to the correct axis of a tensor.
    // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
    // for pre-defined dimension denotations.
    optional string denotation = 3;
  }
  // The dimensions of this tensor.
  repeated Dimension dim = 1;
}

// Operator Sets
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
  // The domain of the operator set being identified.
  // The empty string ("") or absence of this field implies the operator
  // set that is defined as part of the ONNX specification.
  // This field MUST be present in this version of the IR when referring to any other operator set.
  optional string domain = 1;

  // The version of the operator set being identified.
  // This field MUST be present in this version of the IR.
  optional int64 version = 2;
}



ONNX IR version: 0.0.3
Opset version: 9
Producer name: pytorch
Producer version: 0.4



安装 onnx 很简单,我们只需要执行 pip install onnx 即可,这样同时也会把 protobuf 一并安装。

在 debug 界面中可以看到读取出来的数据内容,其格式可以与之前介绍的 onnx.proto 文件相对应;下面的信息即表示了一个神经网络的全部结构,以及每一层(node)的操作类型、属性等各种信息:

ir_version: 1
producer_name: "pytorch"
producer_version: "0.2"
domain: "com.facebook"
graph {
  node {
    input: "1"
    input: "2"
    output: "11"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 5
      ints: 5
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    attribute {
      name: "pads"
      ints: 2
      ints: 2
      ints: 2
      ints: 2
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    attribute {
      name: "group"
      i: 1
  node {
    input: "11"
    input: "3"
    output: "12"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    attribute {
      name: "axis"
      i: 1
  node {
    input: "12"
    output: "13"
    op_type: "Relu"
  node {
    input: "13"
    input: "4"
    output: "15"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 3
      ints: 3
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    attribute {
      name: "pads"
      ints: 1
      ints: 1
      ints: 1
      ints: 1
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    attribute {
      name: "group"
      i: 1
  node {
    input: "15"
    input: "5"
    output: "16"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    attribute {
      name: "axis"
      i: 1
  node {
    input: "16"
    output: "17"
    op_type: "Relu"
  node {
    input: "17"
    input: "6"
    output: "19"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 3
      ints: 3
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    attribute {
      name: "pads"
      ints: 1
      ints: 1
      ints: 1
      ints: 1
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    attribute {
      name: "group"
      i: 1
  node {
    input: "19"
    input: "7"
    output: "20"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    attribute {
      name: "axis"
      i: 1
  node {
    input: "20"
    output: "21"
    op_type: "Relu"
  node {
    input: "21"
    input: "8"
    output: "23"
    op_type: "Conv"
    attribute {
      name: "kernel_shape"
      ints: 3
      ints: 3
    attribute {
      name: "strides"
      ints: 1
      ints: 1
    attribute {
      name: "pads"
      ints: 1
      ints: 1
      ints: 1
      ints: 1
    attribute {
      name: "dilations"
      ints: 1
      ints: 1
    attribute {
      name: "group"
      i: 1
  node {
    input: "23"
    input: "9"
    output: "24"
    op_type: "Add"
    attribute {
      name: "broadcast"
      i: 1
    attribute {
      name: "axis"
      i: 1
  node {
    input: "24"
    output: "25"
    op_type: "Reshape"
    attribute {
      name: "shape"
      ints: 1
      ints: 1
      ints: 3
      ints: 3
      ints: 224
      ints: 224
  node {
    input: "25"
    output: "26"
    op_type: "Transpose"
    attribute {
      name: "perm"
      ints: 0
      ints: 1
      ints: 4
      ints: 2
      ints: 5
      ints: 3
  node {
    input: "26"
    output: "27"
    op_type: "Reshape"
    attribute {
      name: "shape"
      ints: 1
      ints: 1
      ints: 672
      ints: 672
  name: "torch-jit-export"
  initializer {
    dims: 64
    dims: 1
    dims: 5
    dims: 5
    data_type: FLOAT
    name: "2"
    raw_data: "\034


其中model为pytorch的模型,example为输入, export_params=True 代表连带参数一并输出。

model = test_model()
state = torch.load('test.pth')
model.load_state_dict(state['model'], strict=True)

example = torch.rand(1, 3, 128, 128)

torch_out = torch.onnx.export(model,



# output = input.view(input.size(0), -1)
  output = input.view([int(input.size(0)), -1])


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网






