實操教程|PyTorch AutoGrad C++層實現(xiàn)

極市導(dǎo)讀
本文為一篇實操教程,作者介紹了PyTorch AutoGrad C++層實現(xiàn)中各個概念的解釋。 >>加入極市CV技術(shù)交流群,走在計算機視覺的最前沿
autograd依賴的數(shù)據(jù)結(jié)構(gòu)
at::Tensor:shared ptr 指向 TensorImpl
TensorImpl:對 at::Tensor 的實現(xiàn)
包含一個類型為 [AutogradMetaInterface](c10::AutogradMetaInterface)的autograd_meta_,在tensor是需要求導(dǎo)的variable時,會被實例化為[AutogradMeta](c10::AutogradMetaInterface),里面包含了autograd需要的信息
Variable: 就是Tensor,為了向前兼容保留的
using Variable = at::Tensor; 概念上有區(qū)別, Variable是需要計算gradient的,Tensor是不需要計算gradient的Variable的AutogradMeta是對[AutogradMetaInterface](c10::AutogradMetaInterface)的實現(xiàn),里面包含了一個Variable,就是該variable的gradient帶有version和view 會實例化 AutogradMeta, autograd需要的關(guān)鍵信息都在這里
AutoGradMeta : 記錄 Variable 的autograd歷史信息
包含一個叫g(shù)rad_的 Variable, 即AutoGradMeta對應(yīng)的var的梯度tensor包含類型為 Node指針的grad_fn(var在graph內(nèi)部時)和grad_accumulator(var時葉子時), 記錄生成grad_的方法包含 output_nr,標(biāo)識var對應(yīng)grad_fn的輸入編號構(gòu)造函數(shù)包含一個類型為 Edge的gradient_edge,gradient_edge.function就是grad_fn, 另外gradient_edge.input_nr記錄著對應(yīng)grad_fn的輸入編號,會賦值給AutoGradMeta的output_nr
autograd::Edge: 指向autograd::Node的一個輸入
包含類型為 Node指針,表示edge指向的Node包含 input_nr, 表示edge指向的Node的輸入編號
autograd::Node: 對應(yīng)AutoGrad Graph中的Op
是所有autograd op的抽象基類,子類重載apply方法
next_edges_記錄出邊input_metadata_記錄輸入的tensor的metadata實現(xiàn)的子類一般是可求導(dǎo)的函數(shù)和他們的梯度計算op
Node in AutoGrad Graph
Variable通過Edge關(guān)聯(lián)Node的輸入和輸出 多個Edge指向同一個Var時,默認(rèn)做累加 call operator
最重要的方法,實現(xiàn)計算 next_edge
縫合Node的操作 獲取Node的出邊,next_edge(index)/next_edges() add_next_edge(),創(chuàng)建
前向計算
PyTorch通過tracing只生成了后向AutoGrad Graph.
代碼是生成的,需要編譯才能看到對應(yīng)的生成結(jié)果
gen_variable_type.py生成可導(dǎo)版本的op 生成的代碼在 pytorch/torch/csrc/autograd/generated/前向計算時,進(jìn)行了tracing,記錄了后向計算圖構(gòu)建需要的信息 這里以relu為例,代碼在 pytorch/torch/csrc/autograd/generated/VariableType_0.cpp
Tensor relu(const Tensor & self) {auto& self_ = unpack(self, "self", 0);std::shared_ptr<ReluBackward0> grad_fn;if (compute_requires_grad( self )) { // 如果輸入var需要grad// ReluBackward0的類型是Nodegrad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);// collect_next_edges(var)返回輸入var對應(yīng)的指向的// grad_fn(前一個op的backward或者是一個accumulator的)的輸入的Edge// set_next_edges(),在grad_fn中記錄這些Edge(這里完成了后向的構(gòu)圖)grad_fn->set_next_edges(collect_next_edges( self ));// 記錄當(dāng)前var的一個版本grad_fn->self_ = SavedVariable(self, false);}c10::optional<Storage> self__storage_saved =self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;c10::intrusive_ptr<TensorImpl> self__impl_saved;if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();auto tmp = ([&]() {at::AutoNonVariableTypeMode non_var_type_mode(true);return at::relu(self_); // 前向計算})();auto result = std::move(tmp);if (self__storage_saved.has_value())AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());if (grad_fn) {// grad_fn增加一個輸入,記錄輸出var的metadata作為grad_fn的輸入// 輸出var的AutoGradMeta實例化,輸出var的AutoGradMeta指向起grad_fn的輸入set_history(flatten_tensor_args( result ), grad_fn);}return result;}
可以看到和 grad_fn相關(guān)的操作trace了一個op的計算,構(gòu)建了后向計算圖.
后向計算
autograd::backward():計算output var的梯度值,調(diào)用的 run_backward()
autograd::grad() :計算有output var和到特定input的梯度值,調(diào)用的 run_backward()
autograd::run_backward()
對于要求梯度的output var,獲取其指向的grad_fn作為roots,是后向圖的起點 對于有input var的,獲取其指向的grad_fn作為output_edges, 是后向圖的終點 調(diào)用 autograd::Engine::get_default_engine().execute(...)執(zhí)行后向計算
autograd::Engine::execute(...)
創(chuàng)建
GraphTask,記錄了一些配置信息創(chuàng)建
GraphRoot,是一個Node,把所有的roots作為其輸出邊,Node的apply()返回的是roots的grad【這里已經(jīng)得到一個單起點的圖】計算依賴
compute_dependencies(...)從GraphRoot開始,廣度遍歷,記錄所有碰到的grad_fn的指針,并統(tǒng)計grad_fn被遇到的次數(shù),這些信息記錄到GraphTask中 GraphTask初始化:當(dāng)有input var時,判斷后向圖中哪些節(jié)點是真正需要計算的GraphTask執(zhí)行選擇CPU or GPU線程執(zhí)行 以CPU為例,調(diào)用的 autograd::Engine::thread_main(...)
autograd::Engine::thread_main(...)
evaluate_function(...),輸入輸出的處理,調(diào)度call_function(...), 調(diào)用對應(yīng)的Node計算執(zhí)行后向過程中的生成的中間grad Tensor,如果不釋放,可以用于計算高階導(dǎo)數(shù);(同構(gòu)的后向圖,之前的grad tensor是新的輸出,grad_fn變成之前grad_fn的backward,這些新的輸出還可以再backward) 具體的執(zhí)行機制可以支撐單獨開一個Topic分析,在這里討論到后向圖完成構(gòu)建為止.
推薦閱讀
2021-04-11
2021-04-08
2021-04-07

# CV技術(shù)社群邀請函 #
備注:姓名-學(xué)校/公司-研究方向-城市(如:小極-北大-目標(biāo)檢測-深圳)
即可申請加入極市目標(biāo)檢測/圖像分割/工業(yè)檢測/人臉/醫(yī)學(xué)影像/3D/SLAM/自動駕駛/超分辨率/姿態(tài)估計/ReID/GAN/圖像增強/OCR/視頻理解等技術(shù)交流群
每月大咖直播分享、真實項目需求對接、求職內(nèi)推、算法競賽、干貨資訊匯總、與 10000+來自港科大、北大、清華、中科院、CMU、騰訊、百度等名校名企視覺開發(fā)者互動交流~

