ONNX初探
0x0. 背景
最近看了一些ONNX的資料,一個(gè)最大的感受就是這些資料太凌亂了。大多數(shù)都是在介紹ONNX模型轉(zhuǎn)換中碰到的坑點(diǎn)以及解決辦法。很少有文章可以系統(tǒng)的介紹ONNX的背景,分析ONNX格式,ONNX簡(jiǎn)化方法等。所以,綜合了相當(dāng)多資料之后我準(zhǔn)備寫一篇ONNX相關(guān)的文章,希望對(duì)大家有用。
0x1. 什么是ONNX?
簡(jiǎn)單描述一下官方介紹,開放神經(jīng)網(wǎng)絡(luò)交換(Open Neural Network Exchange)簡(jiǎn)稱ONNX是微軟和Facebook提出用來(lái)表示深度學(xué)習(xí)模型的開放格式。所謂開放就是ONNX定義了一組和環(huán)境,平臺(tái)均無(wú)關(guān)的標(biāo)準(zhǔn)格式,來(lái)增強(qiáng)各種AI模型的可交互性。
換句話說(shuō),無(wú)論你使用何種訓(xùn)練框架訓(xùn)練模型(比如TensorFlow/Pytorch/OneFlow/Paddle),在訓(xùn)練完畢后你都可以將這些框架的模型統(tǒng)一轉(zhuǎn)換為ONNX這種統(tǒng)一的格式進(jìn)行存儲(chǔ)。注意ONNX文件不僅僅存儲(chǔ)了神經(jīng)網(wǎng)絡(luò)模型的權(quán)重,同時(shí)也存儲(chǔ)了模型的結(jié)構(gòu)信息以及網(wǎng)絡(luò)中每一層的輸入輸出和一些其它的輔助信息。我們直接從onnx的官方模型倉(cāng)庫(kù)拉一個(gè)yolov3-tiny的onnx模型(地址為:https://github.com/onnx/models/tree/master/vision/object_detection_segmentation/tiny-yolov3/model)用Netron可視化一下看看ONNX模型長(zhǎng)什么樣子。

這里我們可以看到ONNX的版本信息,這個(gè)ONNX模型是由Keras導(dǎo)出來(lái)的,以及模型的輸入輸出等信息,如果你對(duì)模型的輸入輸出有疑問(wèn)可以直接看:https://github.com/onnx/models/blob/master/vision/object_detection_segmentation/tiny-yolov3/README.md。
在獲得ONNX模型之后,模型部署人員自然就可以將這個(gè)模型部署到兼容ONNX的運(yùn)行環(huán)境中去。這里一般還會(huì)設(shè)計(jì)到額外的模型轉(zhuǎn)換工作,典型的比如在Android端利用NCNN部署ONNX格式模型,那么就需要將ONNX利用NCNN的轉(zhuǎn)換工具轉(zhuǎn)換到NCNN所支持的bin和param格式。
但在實(shí)際使用ONNX的過(guò)程中,大多數(shù)人對(duì)ONNX了解得并不多,僅僅認(rèn)為它只是一個(gè)完成模型轉(zhuǎn)換和部署工具人而已,我們可以利用它完成模型轉(zhuǎn)換和部署。正是因?yàn)閷?duì)ONNX的不了解,在模型轉(zhuǎn)換過(guò)程中出現(xiàn)的各種不兼容或者不支持讓很多人浪費(fèi)了大量時(shí)間。這篇文章將從理論和實(shí)踐2個(gè)方面談一談ONNX。
0x2. ProtoBuf簡(jiǎn)介
在分析ONNX組織格式前我們需要了解Protobuf, 如果你比較了解Protobuf可以略過(guò)此節(jié)。ONNX作為一個(gè)文件格式,我們自然需要一定的規(guī)則去讀取我們想要的信息或者是寫入我們需要保存信息。ONNX使用的是Protobuf這個(gè)序列化數(shù)據(jù)結(jié)構(gòu)去存儲(chǔ)神經(jīng)網(wǎng)絡(luò)的權(quán)重信息。熟悉Caffe或者Caffe2的同學(xué)應(yīng)該知道,它們的模型存儲(chǔ)數(shù)據(jù)結(jié)構(gòu)協(xié)議也是Protobuf。這個(gè)從安裝ONNX包的時(shí)候也可以看到:

Protobuf是一種輕便高效的結(jié)構(gòu)化數(shù)據(jù)存儲(chǔ)格式,可以用于結(jié)構(gòu)化數(shù)據(jù)串行化,或者說(shuō)序列化。它很適合做數(shù)據(jù)存儲(chǔ)或數(shù)據(jù)交換格式。可用于通訊協(xié)議、數(shù)據(jù)存儲(chǔ)等領(lǐng)域的語(yǔ)言無(wú)關(guān)、平臺(tái)無(wú)關(guān)、可擴(kuò)展的序列化結(jié)構(gòu)數(shù)據(jù)格式。目前提供了 C++、Java、Python 三種語(yǔ)言的 API(摘自官方介紹)。
Protobuf協(xié)議是一個(gè)以*.proto后綴文件為基礎(chǔ)的,這個(gè)文件描述了用戶自定義的數(shù)據(jù)結(jié)構(gòu)。如果需要了解更多細(xì)節(jié)請(qǐng)參考0x7節(jié)的資料3,這里只是想表達(dá)ONNX是基于Protobuf來(lái)做數(shù)據(jù)存儲(chǔ)和傳輸,那么自然onnx.proto就是ONNX格式文件了,接下來(lái)我們就分析一下ONNX格式。
0x3. ONNX格式分析
這一節(jié)我們來(lái)分析一下ONNX的組織格式,上面提到ONNX中最核心的部分就是onnx.proto(https://github.com/onnx/onnx/blob/master/onnx/onnx.proto)這個(gè)文件了,它定義了ONNX這個(gè)數(shù)據(jù)協(xié)議的規(guī)則和一些其它信息?,F(xiàn)在是2021年1月,這個(gè)文件有700多行,我們沒有必要把這個(gè)文件里面的每一行都貼出來(lái),我們只要搞清楚里面的核心部分即可。在這個(gè)文件里面以message關(guān)鍵字開頭的對(duì)象是我們需要關(guān)心的。我們列一下最核心的幾個(gè)對(duì)象并解釋一下它們之間的關(guān)系。
ModelProtoGraphProtoNodeProtoValueInfoProtoTensorProtoAttributeProto
當(dāng)我們加載了一個(gè)ONNX之后,我們獲得的就是一個(gè)ModelProto,它包含了一些版本信息,生產(chǎn)者信息和一個(gè)GraphProto。在GraphProto里面又包含了四個(gè)repeated數(shù)組,它們分別是node(NodeProto類型),input(ValueInfoProto類型),output(ValueInfoProto類型)和initializer(TensorProto類型),其中node中存放了模型中所有的計(jì)算節(jié)點(diǎn),input存放了模型的輸入節(jié)點(diǎn),output存放了模型中所有的輸出節(jié)點(diǎn),initializer存放了模型的所有權(quán)重參數(shù)。
我們知道要完整的表達(dá)一個(gè)神經(jīng)網(wǎng)絡(luò),不僅僅要知道網(wǎng)絡(luò)的各個(gè)節(jié)點(diǎn)信息,還要知道它們的拓?fù)潢P(guān)系。這個(gè)拓?fù)潢P(guān)系在ONNX中是如何表示的呢?ONNX的每個(gè)計(jì)算節(jié)點(diǎn)都會(huì)有input和output兩個(gè)數(shù)組,這兩個(gè)數(shù)組是string類型,通過(guò)input和output的指向關(guān)系,我們就可以利用上述信息快速構(gòu)建出一個(gè)深度學(xué)習(xí)模型的拓?fù)鋱D。這里要注意一下,GraphProto中的input數(shù)組不僅包含我們一般理解中的圖片輸入的那個(gè)節(jié)點(diǎn),還包含了模型中所有的權(quán)重。例如,Conv層里面的W權(quán)重實(shí)體是保存在initializer中的,那么相應(yīng)的會(huì)有一個(gè)同名的輸入在input中,其背后的邏輯應(yīng)該是把權(quán)重也看成模型的輸入,并通過(guò)initializer中的權(quán)重實(shí)體來(lái)對(duì)這個(gè)輸入做初始化,即一個(gè)賦值的過(guò)程。
最后,每個(gè)計(jì)算節(jié)點(diǎn)中還包含了一個(gè)AttributeProto數(shù)組,用來(lái)描述該節(jié)點(diǎn)的屬性,比如Conv節(jié)點(diǎn)或者說(shuō)卷積層的屬性包含group,pad,strides等等,每一個(gè)計(jì)算節(jié)點(diǎn)的屬性,輸入輸出信息都詳細(xì)記錄在https://github.com/onnx/onnx/blob/master/docs/Operators.md。
0x4. onnx.helper
現(xiàn)在我們知道ONNX是把一個(gè)網(wǎng)絡(luò)的每一層或者說(shuō)一個(gè)算子當(dāng)成節(jié)點(diǎn)node,使用這些Node去構(gòu)建一個(gè)Graph,即一個(gè)網(wǎng)絡(luò)。最后將Graph和其它的生產(chǎn)者信息,版本信息等合并在一起生成一個(gè)Model,也即是最終的ONNX模型文件。在構(gòu)建ONNX模型的時(shí)候,https://github.com/onnx/onnx/blob/master/onnx/helper.py這個(gè)文件非常重要,我們可以利用它提供的make_node,make_graph,make_tensor等等接口完成一個(gè)ONNX模型的構(gòu)建,一個(gè)示例如下:
import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto
# The protobuf definition can be found here:
# https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
# Create one input (ValueInfoProto)
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2])
pads = helper.make_tensor_value_info('pads', TensorProto.FLOAT, [1, 4])
value = helper.make_tensor_value_info('value', AttributeProto.FLOAT, [1])
# Create one output (ValueInfoProto)
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 4])
# Create a node (NodeProto) - This is based on Pad-11
node_def = helper.make_node(
'Pad', # node name
['X', 'pads', 'value'], # inputs
['Y'], # outputs
mode='constant', # attributes
)
# Create the graph (GraphProto)
graph_def = helper.make_graph(
[node_def],
'test-model',
[X, pads, value],
[Y],
)
# Create the model (ModelProto)
model_def = helper.make_model(graph_def, producer_name='onnx-example')
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')
這個(gè)官方示例為我們演示了如何使用onnx.helper的make_tensor,make_tensor_value_info,make_attribute,make_node,make_graph,make_node等方法來(lái)完整構(gòu)建了一個(gè)ONNX模型。需要注意的是在上面的例子中,輸入數(shù)據(jù)是一個(gè)一維Tensor,初始維度為[2],這也是為什么經(jīng)過(guò)維度為[1,4]的Pad操作之后獲得的輸出Tensor維度為[3,4]。另外由于Pad操作是沒有帶任何權(quán)重信息的,所以當(dāng)你打印ONNX模型時(shí),ModelProto的GraphProto是沒有initializer這個(gè)屬性的。
0x5. onnx-simplifier
原本這里是要總結(jié)一些使用ONNX進(jìn)行模型部署經(jīng)常碰到一些因?yàn)榘姹炯嫒菪?,或者各種框架OP沒有對(duì)齊等原因?qū)е碌母鞣NBUG。但是這樣會(huì)顯得文章很長(zhǎng),所以這里以一個(gè)經(jīng)典的Pytorch轉(zhuǎn)ONNX的reshape問(wèn)題為例子,來(lái)嘗試講解一下大老師的onnx-simplifier是怎么處理的,個(gè)人認(rèn)為這個(gè)問(wèn)題是基于ONNX進(jìn)行模型部署最經(jīng)典的問(wèn)題。希望在解決這個(gè)問(wèn)題的過(guò)程中大家能有所收獲。
問(wèn)題發(fā)生在當(dāng)我們想把下面這段代碼導(dǎo)出ONNX模型時(shí):
import torch
class JustReshape(torch.nn.Module):
def __init__(self):
super(JustReshape, self).__init__()
def forward(self, x):
return x.view((x.shape[0], x.shape[1], x.shape[3], x.shape[2]))
net = JustReshape()
model_name = 'just_reshape.onnx'
dummy_input = torch.randn(2, 3, 4, 5)
torch.onnx.export(net, dummy_input, model_name, input_names=['input'], output_names=['output'])
由于這個(gè)模型輸入維度是固定的,所以我們期望模型是這樣的:

但是,即使使用了ONNX的polished工具也只能獲得下面的模型:

要解決這個(gè)問(wèn)題,有兩種方法,第一種是做一個(gè)強(qiáng)制類型轉(zhuǎn)換,將x.shape[0]類似的變量強(qiáng)制轉(zhuǎn)換為常量即int(x.shape[0]),或者使用大老師的onnx-simplifer來(lái)解決這一問(wèn)題。
之前一直好奇onnx-simplifer是怎么做的,最近對(duì)ONNX有了一些理解之后也能逐步看懂做法了。我來(lái)嘗試解釋一下。onnx-simplifer的核心思路就是利用onnxruntime推斷一遍ONNX的計(jì)算圖,然后使用常量輸出替代冗余的運(yùn)算OP。主體代碼為:
def simplify(model: Union[str, onnx.ModelProto], check_n: int = 0, perform_optimization: bool = True,
skip_fuse_bn: bool = False, input_shapes: Optional[TensorShapes] = None, skipped_optimizers: Optional[Sequence[str]] = None, skip_shape_inference=False) \
-> Tuple[onnx.ModelProto, bool]:
if input_shapes is None:
input_shapes = {}
if type(model) == str:
# 加載ONNX模型
model = onnx.load(model)
# 檢查ONNX模型格式是否正確,圖結(jié)構(gòu)是否完整,節(jié)點(diǎn)是否正確等
onnx.checker.check_model(model)
# 深拷貝一份原始ONNX模型
model_ori = copy.deepcopy(model)
if not skip_shape_inference:
# 獲取ONNX模型中特征圖的尺寸
model = infer_shapes(model)
input_shapes = check_and_update_input_shapes(model, input_shapes)
if perform_optimization:
model = optimize(model, skip_fuse_bn, skipped_optimizers)
const_nodes = get_constant_nodes(model)
res = forward_for_node_outputs(
model, const_nodes, input_shapes=input_shapes)
const_nodes = clean_constant_nodes(const_nodes, res)
model = eliminate_const_nodes(model, const_nodes, res)
onnx.checker.check_model(model)
if not skip_shape_inference:
model = infer_shapes(model)
if perform_optimization:
model = optimize(model, skip_fuse_bn, skipped_optimizers)
check_ok = check(model_ori, model, check_n, input_shapes=input_shapes)
return model, check_ok
上面有一行:model = infer_shapes(model) 是獲取ONNX模型中特征圖的尺寸,它的具體實(shí)現(xiàn)如下:
def infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto:
try:
model = onnx.shape_inference.infer_shapes(model)
except:
pass
return model
我們保存一下調(diào)用了這個(gè)接口之后的ONNX模型,并將其可視化看一下:

相對(duì)于原始的ONNX模型,現(xiàn)在每一條線都新增了一個(gè)shape信息,代表它的前一個(gè)特征圖的shape是怎樣的。
接著,程序使用到了check_and_update_input_shapes接口,這個(gè)接口的代碼示例如下,它可以用來(lái)判斷輸入的格式是否正確以及輸入模型是否存在所有的指定輸入節(jié)點(diǎn)。
def check_and_update_input_shapes(model: onnx.ModelProto, input_shapes: TensorShapes) -> TensorShapes:
input_names = get_input_names(model)
if None in input_shapes:
if len(input_names) == 1:
input_shapes[input_names[0]] = input_shapes[None]
del input_shapes[None]
else:
raise RuntimeError(
'The model has more than 1 inputs, please use the format "input_name:dim0,dim1,...,dimN" in --input-shape')
for x in input_shapes:
if x not in input_names:
raise RuntimeError(
'The model doesn\'t have input named "{}"'.format(x))
return input_shapes
在這個(gè)例子中,如果我們指定input_shapes為:{'input': [2, 3, 4, 5]},那么這個(gè)函數(shù)的輸出也為{'input': [2, 3, 4, 5]}。如果不指定,輸出就是{}。驗(yàn)證這個(gè)函數(shù)的調(diào)用代碼如下所示:

確定了輸入沒有問(wèn)題之后,程序會(huì)根據(jù)用戶指定是否優(yōu)化ONNX模型進(jìn)入優(yōu)化函數(shù),函數(shù)定義如下:
def optimize(model: onnx.ModelProto, skip_fuse_bn: bool, skipped_optimizers: Optional[Sequence[str]]) -> onnx.ModelProto:
"""
:model參數(shù): 待優(yōu)化的ONXX模型.
:return: 優(yōu)化之后的ONNX模型.
簡(jiǎn)化之前, 使用這個(gè)方法產(chǎn)生會(huì)在'forward_all'用到的ValueInfo
簡(jiǎn)化之后,使用這個(gè)方法去折疊前一步產(chǎn)生的常量到initializer中并且消除沒被使用的常量
"""
onnx.checker.check_model(model)
onnx.helper.strip_doc_string(model)
optimizers_list = [
'eliminate_deadend',
'eliminate_nop_dropout',
'eliminate_nop_cast',
'eliminate_nop_monotone_argmax', 'eliminate_nop_pad',
'extract_constant_to_initializer', 'eliminate_unused_initializer',
'eliminate_nop_transpose',
'eliminate_nop_flatten', 'eliminate_identity',
'fuse_add_bias_into_conv',
'fuse_consecutive_concats',
'fuse_consecutive_log_softmax',
'fuse_consecutive_reduce_unsqueeze', 'fuse_consecutive_squeezes',
'fuse_consecutive_transposes', 'fuse_matmul_add_bias_into_gemm',
'fuse_pad_into_conv', 'fuse_transpose_into_gemm', 'eliminate_duplicate_initializer'
]
if not skip_fuse_bn:
optimizers_list.append('fuse_bn_into_conv')
if skipped_optimizers is not None:
for opt in skipped_optimizers:
try:
optimizers_list.remove(opt)
except ValueError:
pass
model = onnxoptimizer.optimize(model, optimizers_list,
fixed_point=True)
onnx.checker.check_model(model)
return model
這個(gè)函數(shù)的功能是對(duì)原始的ONNX模型做一些圖優(yōu)化工作,比如merge_bn,fuse_add_bias_into_conv等等。我們使用onnx.save保存一下這個(gè)例子中圖優(yōu)化后的模型,可以發(fā)現(xiàn)它和優(yōu)化前的可視化效果是一樣的,如下圖所示:

這是因?yàn)樵谶@個(gè)模型中是沒有上面列舉到的那些可以做圖優(yōu)化的情況,但是當(dāng)我們打印一下ONNX模型我們會(huì)發(fā)現(xiàn)optimize過(guò)后的ONNX模型多出一些initializer數(shù)組:

這些數(shù)組存儲(chǔ)的就是這個(gè)圖中那些常量OP的具體值,通過(guò)這個(gè)處理我們就可以調(diào)用get_constant_nodes函數(shù)來(lái)獲取ONNX模型的常量OP了,這個(gè)函數(shù)的詳細(xì)解釋如下:
def get_constant_nodes(m: onnx.ModelProto) -> List[onnx.NodeProto]:
const_nodes = []
# 如果節(jié)點(diǎn)的name在ONNX的GraphProto的initizlizer數(shù)組里面,它就是靜態(tài)的tensor
const_tensors = [x.name for x in m.graph.initializer]
# 顯示的常量OP也加進(jìn)來(lái)
const_tensors.extend([node.output[0]
for node in m.graph.node if node.op_type == 'Constant'])
# 一些節(jié)點(diǎn)的輸出shape是由輸入節(jié)點(diǎn)決定的,我們認(rèn)為這個(gè)節(jié)點(diǎn)的輸出shape并不是常量,
# 所以我們不需要簡(jiǎn)化這種節(jié)點(diǎn)
dynamic_tensors = []
# 判斷是否為動(dòng)態(tài)OP
def is_dynamic(node):
if node.op_type in ['NonMaxSuppression', 'NonZero', 'Unique'] and node.input[0] not in const_tensors:
return True
if node.op_type in ['Reshape', 'Expand', 'Upsample', 'ConstantOfShape'] and len(node.input) > 1 and node.input[1] not in const_tensors:
return True
if node.op_type in ['Resize'] and ((len(node.input) > 2 and node.input[2] not in const_tensors) or (len(node.input) > 3 and node.input[3] not in const_tensors)):
return True
return False
for node in m.graph.node:
if any(x in dynamic_tensors for x in node.input):
dynamic_tensors.extend(node.output)
elif node.op_type == 'Shape':
const_nodes.append(node)
const_tensors.extend(node.output)
elif is_dynamic(node):
dynamic_tensors.extend(node.output)
elif all([x in const_tensors for x in node.input]):
const_nodes.append(node)
const_tensors.extend(node.output)
# 深拷貝
return copy.deepcopy(const_nodes)
在這個(gè)例子中,我們打印一下執(zhí)行這個(gè)獲取常量OP函數(shù)之后,Graph中有哪些OP被看成了常量OP。

獲取了模型中所有的常量OP之后,我們需要把所有的靜態(tài)節(jié)點(diǎn)擴(kuò)展到ONNX Graph的輸出節(jié)點(diǎn)列表中,然后利用onnxruntme執(zhí)行一次forward:
def forward_for_node_outputs(model: onnx.ModelProto, nodes: List[onnx.NodeProto],
input_shapes: Optional[TensorShapes] = None) -> Dict[str, np.ndarray]:
if input_shapes is None:
input_shapes = {}
model = copy.deepcopy(model)
# nodes 是Graph中所有的靜態(tài)OP
add_features_to_output(model, nodes)
res = forward(model, input_shapes=input_shapes)
return res
其中add_features_to_output的定義如下:
def add_features_to_output(m: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> None:
"""
Add features to output in pb, so that ONNX Runtime will output them.
:param m: the model that will be run in ONNX Runtime
:param nodes: nodes whose outputs will be added into the graph outputs
"""
# ONNX模型的graph擴(kuò)展輸出節(jié)點(diǎn),獲取所有靜態(tài)OP的輸出和原始輸出節(jié)點(diǎn)的輸出
for node in nodes:
for output in node.output:
m.graph.output.extend([onnx.ValueInfoProto(name=output)])
最后的forward函數(shù)就是利用onnxruntime推理獲得我們指定的輸出節(jié)點(diǎn)的值。這個(gè)函數(shù)這里不進(jìn)行解釋。推理完成之后,進(jìn)入下一個(gè)函數(shù)clean_constant_nodes,這個(gè)函數(shù)的定義如下:
def clean_constant_nodes(const_nodes: List[onnx.NodeProto], res: Dict[str, np.ndarray]):
"""
It seems not needed since commit 6f2a72, but maybe it still prevents some unknown bug
:param const_nodes: const nodes detected by `get_constant_nodes`
:param res: The dict containing all tensors, got by `forward_all`
:return: The constant nodes which have an output in res
"""
return [node for node in const_nodes if node.output[0] in res]
這個(gè)函數(shù)是用來(lái)清洗那些沒有被onnxruntime推理的靜態(tài)節(jié)點(diǎn),但通過(guò)上面的optimize邏輯,我們的graph中其實(shí)已經(jīng)不存在這個(gè)情況了(沒有被onnxruntime推理的靜態(tài)節(jié)點(diǎn)在圖優(yōu)化階段會(huì)被優(yōu)化掉),因此這個(gè)函數(shù)理論上是可以刪除的。這個(gè)地方是為了避免刪除掉有可能引發(fā)其它問(wèn)題就保留了。
不過(guò)從一些實(shí)際經(jīng)驗(yàn)來(lái)看,還是保留吧,畢竟不能保證ONNX的圖優(yōu)化就完全正確,前段時(shí)間剛發(fā)現(xiàn)了TensorRT圖優(yōu)化出了一個(gè)BUG。保留這個(gè)函數(shù)可以提升一些程序的穩(wěn)定性。

接下來(lái)就是這個(gè)onnx-simplifier最核心的步驟了,即將常量節(jié)點(diǎn)從原始的ONNX Graph中移除,函數(shù)接口為eliminate_const_nodes:
def eliminate_const_nodes(model: onnx.ModelProto, const_nodes: List[onnx.NodeProto],
res: Dict[str, np.ndarray]) -> onnx.ModelProto:
"""
:model參數(shù): 原始ONNX模型
:const_nodes參數(shù): 使用`get_constant_nodes`獲得的靜態(tài)OP
:res參數(shù): 包含所有輸出Tensor的字典
:return: 簡(jiǎn)化后的模型. 所有冗余操作都已刪除.
"""
for i, node in enumerate(model.graph.node):
if node in const_nodes:
for output in node.output:
new_node = copy.deepcopy(node)
new_node.name = "node_" + output
new_node.op_type = 'Constant'
new_attr = onnx.helper.make_attribute(
'value',
onnx.numpy_helper.from_array(res[output], name=output)
)
del new_node.input[:]
del new_node.attribute[:]
del new_node.output[:]
new_node.output.extend([output])
new_node.attribute.extend([new_attr])
insert_elem(model.graph.node, i + 1, new_node)
del model.graph.node[i]
return model
運(yùn)行這個(gè)函數(shù)之后我們獲得的ONNX模型可視化結(jié)果是這樣子的:

注意,這里獲得的ONNX模型中雖然常量節(jié)點(diǎn)已經(jīng)從Graph中斷開了,即相當(dāng)于這個(gè)DAG里面多了一些單獨(dú)的點(diǎn),但是這些點(diǎn)還是存在的。因此,我們?cè)賵?zhí)行一次optimize就可以獲得最終簡(jiǎn)化后的ONNX模型了。最終簡(jiǎn)化后的ONNX模型如下圖所示:

0x6. 總結(jié)
介于篇幅原因,介紹ONNX的第一篇文章就介紹到這里了,后續(xù)可能會(huì)結(jié)合更多實(shí)踐的經(jīng)驗(yàn)來(lái)談?wù)凮NNX了,例如OneFlow模型導(dǎo)出ONNX進(jìn)行部署???傊?,文章很長(zhǎng),謝謝你的觀看,希望這篇文章有幫助到你。最后歡迎star大老師的onnx-simplifier。
0x7. 參考資料
【1】https://zhuanlan.zhihu.com/p/86867138 【2】https://oldpan.me/archives/talk-about-onnx 【3】https://blog.csdn.net/chengzi_comm/article/details/53199278 【4】https://www.jianshu.com/p/a24c88c0526a 【5】https://bindog.github.io/blog/2020/03/13/deep-learning-model-convert-and-depoly/ 【6】 https://github.com/daquexian/onnx-simplifier
歡迎關(guān)注GiantPandaCV, 在這里你將看到獨(dú)家的深度學(xué)習(xí)分享,堅(jiān)持原創(chuàng),每天分享我們學(xué)習(xí)到的新鮮知識(shí)。( ? ?ω?? )?
有對(duì)文章相關(guān)的問(wèn)題,或者想要加入交流群,歡迎添加BBuf微信:
為了方便讀者獲取資料以及我們公眾號(hào)的作者發(fā)布一些Github工程的更新,我們成立了一個(gè)QQ群,二維碼如下,感興趣可以加入。
