1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        PyTorch 源碼解讀之即時(shí)編譯篇

        共 84968字,需瀏覽 170分鐘

         ·

        2021-06-13 21:57


        作者丨OpenMMLab
        來源丨h(huán)ttps://zhuanlan.zhihu.com/p/361101354
        編輯丨GiantPandaCV

        前言

        torch 從 1.0 開始支持了 jit 模塊,其大概包括以下幾個(gè)部分:

        • 一種新的計(jì)算圖中間表示 (Intermediate Representation),之后簡(jiǎn)稱為 IR.
        • 從 Python 代碼導(dǎo)出IR的兩種方法,即 trace 與 script.
        • IR 優(yōu)化以及 IR 的解釋器(翻譯為具體的運(yùn)算 op).

        這篇解讀會(huì)分為以下幾個(gè)部分:

        • jit 的簡(jiǎn)單介紹以及兩種導(dǎo)出方式的使用例子
        • jit 中 IR 的形式
        • 導(dǎo)出 IR 的兩種方式,trace 與 script 的源碼解讀
        • IR 優(yōu)化的簡(jiǎn)單介紹

        1 jit 的簡(jiǎn)單介紹以及使用例子

        JIT 簡(jiǎn)介

        如前言,這篇解讀雖然標(biāo)題是 JIT,但是真正稱得上即時(shí)編譯器的部分是在導(dǎo)出 IR 后,即優(yōu)化 IR 計(jì)算圖,并且解釋為對(duì)應(yīng) operation 的過程,即PyTorch jit 相關(guān) code 帶來的優(yōu)化一般是計(jì)算圖級(jí)別優(yōu)化,比如部分運(yùn)算的融合,但是對(duì)具體算子(如卷積)是沒有特定優(yōu)化的,其依舊調(diào)用 torch的基礎(chǔ)算子庫(kù).

        大家也可以在導(dǎo)出 IR 也就是 torchscript 后,使用其他的編譯優(yōu)化或者解釋器,如現(xiàn)在也有script to a TensorRT engine,TRTtorch轉(zhuǎn) tensorRT 的方案。

        trace

        給大家一個(gè)簡(jiǎn)單例子。

        import torchvision.models as models
            resnet = torch.jit.trace(models.resnet18(),torch.rand(1,3,224,224))
            output=resnet(torch.ones(1,3,224,224))
            print(output)
            output=resnet(torch.ones(1,3,224,224))
            resnet.save('resnet.pt')

        output 便是我們導(dǎo)出的中間表示,其可以 save 下來,在其他框架使用

        我們可以看下 output 中的 IR,即 torchscript 表征的計(jì)算圖是什么樣子的。

        graph(%self.1 : __torch__.torchvision.models.resnet.___torch_mangle_194.ResNet,
            %input.1 : Float(1:1505283:50176224:224224:1, requires_grad=0, device=cpu)):
            %1472 : __torch__.torch.nn.modules.linear.___torch_mangle_193.Linear = prim::GetAttr[name="fc"](%self.1)
            %1469 : __torch__.torch.nn.modules.pooling.___torch_mangle_192.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
            %1468 : __torch__.torch.nn.modulesjieshao.container.___torch_mangle_191.Sequential = prim::GetAttr[name="layer4"](%self.1)
            %1422 : __torch__.torch.nn.modules.container.___torch_mangle_175.Sequential = prim::GetAttr[name="layer3"](%self.1)
            ....
            %1556 : Tensor = prim::CallMethod[name="forward"](%1469, %1555)
            %1202 : int = prim::Constant[value=1]()
            %1203 : int = prim::Constant[value=-1]()
            %input : Float(1:512512:1, requires_grad=1, device=cpu) = aten::flatten(%1556, %1202, %1203
            %1557 : Tensor = prim::CallMethod[name="forward"](%1472, %input)
            return (%1557)

        這便是 trace 方法的使用,其核心實(shí)現(xiàn)的入口便是torch.jit.trace,參數(shù)為你需要導(dǎo)出的 model,以及合法輸入input,其大概原理恰如其名,便是跟蹤模型 inference 過程,將模型對(duì)輸入進(jìn)行的操作逐一記錄下來,并對(duì)應(yīng)到 IR 的操作,從而得到原本模型forward 的 IR。

        ote :但是這種實(shí)現(xiàn)方式有很明顯的缺陷,PyTorch 作為動(dòng)態(tài)圖網(wǎng)絡(luò),會(huì)有很多的 input dependent的控制流語(yǔ)句,根據(jù)輸入的不同可能會(huì)執(zhí)行情況會(huì)不同(if 或者 變長(zhǎng)的 loop),這樣就無法 trace 到完整的計(jì)算圖。如下就是一個(gè) trace

        失敗的 case:

        if x > 2.0:
            r = torch.tensor(1.0)
            else:
             r = torch.tensor(2.0)
            return r
            
        ftrace = torch.jit.trace(test, (torch.ones(1)))
        y = torch.ones(1) * 5
        print(ftrace(y))
        # results: tensor(2.)
        # 因?yàn)檩斎胫蛔吡说姆种lse

        script

        @torch.jit.script
        def foo(x, y):
            if x.max() > y.max():
                r = x
            else:
                r = y
            return r
            
        print(foo.graph)
            
        print(foo(torch.Tensor([0]), torch.Tensor([1])))
        print(foo(torch.Tensor([1]), torch.Tensor([0])))
            
        graph(%x.1 : Tensor,
              %y.1 : Tensor):
          %3 : Tensor = aten::max(%x.1
          %5 : Tensor = aten::max(%y.1
          # 可以看到確實(shí)捕捉到了控制語(yǔ)句,
          %6 : Tensor = aten::gt(%3, %5
          %7 : bool = aten::Bool(%6
          %r : Tensor = prim::If(%7
            block0():
              -> (%x.1)
            block1():
              -> (%y.1)
          return (%r)
            
        tensor([1.])
        tensor([1.])

        script 使用是在你需要的地方 (fuction or nn.Module (默認(rèn)追蹤 forward函數(shù)))掛載裝飾器torch.jit.script,其轉(zhuǎn)換方式跟 trace 是完全不同的思路,script 直接解析你的 PyTorch代碼,通過語(yǔ)法分析解析你的邏輯為一棵語(yǔ)法樹,然后轉(zhuǎn)換為中間表示 IR。

        Note: 雖然其可以解決 trace 存在無法追蹤動(dòng)態(tài)邏輯的問題,但是 Python 作為靈活度極高的語(yǔ)法, 想完整支持解析各種 Python 操作幾乎是不可能的,因此我們需要額外的時(shí)間熟悉哪些寫法是可以被解析的,讓我們寫代碼的體驗(yàn)大打折扣。

        兩者結(jié)合

        兩者各有優(yōu)勢(shì),支持靈活集合。

        import torch
        import torch.nn as nn
        import torch.nn.functional as F
            
        class MyModule(nn.Module):
            def __init__(self):
                super(MyModule, self).__init__()
                # torch.jit.trace produces a ScriptModule's conv1 and conv2
                self.conv1 = torch.jit.trace(nn.Conv2d(1205), torch.rand(111616))
                self.conv2 = torch.jit.trace(nn.Conv2d(20205), torch.rand(1201616))
            
            def forward(self, input):
                input = F.relu(self.conv1(input))
                input = F.relu(self.conv2(input))
                return input
            
        scripted_module = torch.jit.script(MyModule())

        因此實(shí)際使用時(shí)候,可以有如下準(zhǔn)則:

        1 大部分情況 model 只有 tensor operation,就直接無腦 tracing

        2 帶 control-flow (if-else, for-loop) 的,上 scripting

        3 碰上 scripting 不能 handle 的語(yǔ)法,要么重寫,要么把 tracing 和 scripting 合起來用(比如說只在有 control-

        flow 的代碼用 scripting,其他用 tracing)

        如何擴(kuò)展

        trace 與 script 都不能轉(zhuǎn)換第三方 Python 庫(kù)中的函數(shù),盡量所有代碼都使用 PyTorch 實(shí)現(xiàn), 自定義 op 需要注冊(cè)成 jit

        操作( torch 的 op 其實(shí)也注冊(cè)了),最后轉(zhuǎn)成 torchscript。

            TORCH_LIBRARY(my_ops, m) {
              m.def("warp_perspective", warp_perspective);
            }

        更多可以參考官方教程

        1 EXTENDING TORCHSCRIPT WITH CUSTOM C++ OPERATORS

        2 IR (torchscript)的基本表示

        PyTorch 中的各種設(shè)計(jì)(parameter,計(jì)算節(jié)點(diǎn)等)在 torchscript 中是如何對(duì)應(yīng)的呢?

        這便是轉(zhuǎn)換出的 IR 結(jié)果,torchscrip 以下結(jié)構(gòu)組合。

        名稱source code簡(jiǎn)介
        Modulesmodule.h對(duì)標(biāo) nn.Module
        Parametersmodule.h對(duì)標(biāo) PyTorch 的 parameter
        MethodMethod.h包括 FunctionSchema 方法描述,Graph 實(shí)際計(jì)算圖,GraphExecutor do the optimization and execution
        FunctionSchemafunction_schema.h描述參數(shù)與返回類型
        Graphir.h定義 function 的具體實(shí)現(xiàn),包括 Nodes,Blocks,Values
        Nodesir.h一個(gè)指令,如一次卷積運(yùn)算,一次矩陣運(yùn)算
        Blockir.h控制語(yǔ)句 if,loop + list of nodes

        還有with,Value,Type

            # %x.1 value
            graph(%x.1 : Tensor,
                  %y.1 : Tensor):
                  # aten::max 就是一個(gè)Node
                  # Tensor: Type-TensorType
              %3 : Tensor = aten::max(%x.1
              %5 : Tensor = aten::max(%y.1
              %6 : Tensor = aten::gt(%3, %5
              %7 : bool = aten::Bool(%6
              %r : Tensor = prim::If(%7
               # Blocks 
                block0():
                  -> (%x.1)
                block1():
                  -> (%y.1)
              return (%r)

        3 導(dǎo)出 IR 的兩種方式,trace 與 script

        因?yàn)槠渚唧w實(shí)現(xiàn)頗為復(fù)雜,粘貼的源碼也僅僅保留了簡(jiǎn)單 case 跑過的分支,并且省去了絕大部分細(xì)節(jié),讀者如有需要更多細(xì)節(jié)可以自行去源碼查閱。

        trace 實(shí)現(xiàn)

            func,
                example_inputs,
                optimize=None,
                check_trace=True,
                check_inputs=None,
                check_tolerance=1e-5,
                strict=True,
                _force_outplace=False,
                _module_class=None,
                _compilation_unit=_python_cu,
            ):


                # 發(fā)現(xiàn)是nn.Module instacene forward, 追蹤forward
                if isinstance(func, torch.nn.Module):
                    return trace_module(
                        func,
                        {"forward": example_inputs},
                        None,
                        check_trace,
                        wrap_check_inputs(check_inputs),
                        check_tolerance,
                        strict,
                        _force_outplace,
                        _module_class,
                    )
                # 傳進(jìn)來的是某個(gè)module instance的forward
                if (
                    hasattr(func, "__self__")
                    and isinstance(func.__self__, torch.nn.Module)
                    and func.__name__ == "forward"
                ):
                    return trace_module(
                        func.__self__,
                        {"forward": example_inputs},
                        None,
                        check_trace,
                        wrap_check_inputs(check_inputs),
                        check_tolerance,
                        strict,
                        _force_outplace,
                        _module_class,
                    )
                # 一個(gè)查找變量名的接口
                var_lookup_fn = _create_interpreter_name_lookup_fn(0)
            
               # C++ 入口 
               traced = torch._C._create_function_from_trace(
                   name, func, example_inputs, var_lookup_fn, strict,_force_outplace
                )
            
                # 檢查traced 與 原func是否有差異
                if check_trace:
                    if check_inputs is not None:
                        _check_trace(
                            check_inputs,
                            func,
                            traced,
                            check_tolerance,
                            strict,
                            _force_outplace,
                            False,
                            _module_class,
                        )
                    else:
                        _check_trace(
                            [example_inputs],
                            func,
                            traced,
                            check_tolerance,
                            strict,
                            _force_outplace,
                            False,
                            _module_class,
                        )
            
                return traced

        我們發(fā)現(xiàn)經(jīng)過簡(jiǎn)單的判斷,代碼便進(jìn)入了 C++ 相關(guān)函數(shù)

            traced = torch._C._create_function_from_trace(
                    name, func, example_inputs, var_lookup_fn, strict, _force_outplace
                )

        我們?nèi)?C++ 中看下發(fā)生了什么

            std::pair<std::shared_ptr<TracingState>, Stack> trace(
                Stack inputs,
                const std::function<Stack(Stack)>& traced_fn,
                std::function<std::string(const Variable&)> var_name_lookup_fn,
                bool strict,
                bool force_outplace,
                Module* self)
         
        {
              try {
            
                auto state = std::make_shared<TracingState>();
                # setTracingState 將state 這個(gè)實(shí)例set下來,在之后計(jì)算節(jié)點(diǎn)get出來insert計(jì)算過程
                setTracingState(state);
            
                #state這個(gè)數(shù)據(jù)結(jié)構(gòu)會(huì)在forward過程中存儲(chǔ)trace到的計(jì)算過程
                if (self) {
                  Value* self_value = state->graph->insertInput(0"self")->setType(
                      self->_ivalue()->type());
                  gatherParametersAndBuffers(state, self_value, *self, {"__module"});
                }
            
                for (IValue& input : inputs) {
                  input = addInput(state, input, input.type(), state->graph->addInput());
                }
                auto graph = state->graph;
                # 將python中的變量名解析函數(shù)綁定下來
                getTracingState()->lookup_var_name_fn = std::move(var_name_lookup_fn);
                getTracingState()->strict = strict;
                getTracingState()->force_outplace = force_outplace;
            
                # 開始forward,在計(jì)算發(fā)生時(shí),會(huì)把計(jì)算記錄到state中
                auto out_stack = traced_fn(inputs);
            
                // Exit a trace, treating 'out_stack' as the outputs of the trace.  These
                // are the variables whose values will be computed upon subsequent
                // invocations of the trace.
                size_t i = 0;
                for (auto& output : out_stack) {
                  // NB: The stack is in "reverse" order, so when we pass the diagnostic
                  // number we need to flip it based on size.
                  state->graph->registerOutput(
                      state->getOutput(output, out_stack.size() - i));
                  i++;
                }
                setTracingState(nullptr);
            
                if (getInlineEverythingMode()) {
                  Inline(*graph);
                }
                FixupTraceScopeBlocks(graph, self);
                NormalizeOps(graph);
                return {state, out_stack};
              } catch (...) {
                tracer::abandon();
                throw;
              }
            }

        那么具體記錄 operation 的過程發(fā)生在哪里呢?

        pytorch/torch/csrc/jit/runtime/register_c10_ops.cpp

            Operator createOperatorFromC10_withTracingHandledHere(
                const c10::OperatorHandle& op)
         
        {
              return Operator(op, [op](Stack& stack) {
                const auto input_size = op.schema().arguments().size();
                const auto output_size = op.schema().returns().size();
            
                Node* node = nullptr;
                std::shared_ptr<jit::tracer::TracingState> tracer_state;
            
                // trace the input before unwrapping, otherwise we may lose
                // the input information
                if (jit::tracer::isTracing()) {
                  # 獲取 tracer_state
                  tracer_state = jit::tracer::getTracingState();
                  auto symbol = Symbol::fromQualString(op.schema().name());
                  const auto& graph = tracer::getTracingState()->graph;
                  node = graph->create(symbol, 0);
                  tracer::recordSourceLocation(node);
                  const auto& args = op.schema().arguments();
                  int i = 0;
                  # 記錄args 
                  for (auto iter = stack.end() - input_size; iter != stack.end();
                       ++iter, ++i) {
                    // TODO we need to refactor graph APIs (e.g., addInputs)
                    // appropriately; after that, we can get rid of the giant if-else
                    // block we will clean this tech debt together in the following PRs
                    auto type = args[i].type();
                    if (type->kind() == TypeKind::OptionalType) {
                      if (iter->isNone()) {
                        Value* none = graph->insertNode(graph->createNone())->output();
                        node->addInput(none);
                        continue;
                      } else {
                        type = type->expect<OptionalType>()->getElementType();
                      }
                    }
                    if (type->isSubtypeOf(TensorType::get())) {
                      AT_ASSERT(iter->isTensor());
                      tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
                    } else if (type->kind() == TypeKind::FloatType) {
                      AT_ASSERT(iter->isDouble());
                      tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
                    } else if (type->kind() == TypeKind::IntType) {
                      AT_ASSERT(iter->isInt());
                      tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
                    } else if (type->kind() == TypeKind::BoolType) {
                      AT_ASSERT(iter->isBool());
                      tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
                    } else if (type->kind() == TypeKind::StringType) {
                      AT_ASSERT(iter->isString());
                      tracer::addInputs(node, args[i].name().c_str(), iter->toStringRef());
                    } else if (type->kind() == TypeKind::NumberType) {
                      tracer::addInputs(node, args[i].name().c_str(), iter->toScalar());
                    } else if (type->kind() == TypeKind::ListType) {
                      const auto& elem_type = type->expect<ListType>()->getElementType();
                      if (elem_type->isSubtypeOf(TensorType::get())) {
                        AT_ASSERT(iter->isTensorList());
                        auto list = iter->toTensorVector();
                        tracer::addInputs(node, args[i].name().c_str(), list);
                      } else if (elem_type->kind() == TypeKind::FloatType) {
                        AT_ASSERT(iter->isDoubleList());
                        // NB: now, tracer doesn't support tracing double list. We add
                        // special handling here, since in our case, we assume that all the
                        // doubles in the list are constants
                        auto value = iter->toDoubleVector();
                        std::vector<Value*> info(value.size());
                        for (size_t value_index = 0; value_index < value.size();
                             ++value_index) {
                          info[value_index] = graph->insertConstant(value[value_index]);
                          tracer::recordSourceLocation(info[value_index]->node());
                        }
                        node->addInput(
                            graph
                                ->insertNode(graph->createList(jit::FloatType::get(), info))
                                ->output());
                      } else if (elem_type->kind() == TypeKind::IntType) {
                        AT_ASSERT(iter->isIntList());
                        tracer::addInputs(
                            node, args[i].name().c_str(), iter->toIntVector());
                      } else if (elem_type->kind() == TypeKind::BoolType) {
                        AT_ASSERT(iter->isBoolList());
                        tracer::addInputs(
                            node, args[i].name().c_str(), iter->toBoolList().vec());
                      } else {
                        throw std::runtime_error(
                            "unsupported input list type: " + elem_type->str());
                      }
                    } else if (iter->isObject()) {
                      tracer::addInputs(node, args[i].name().c_str(), iter->toObject());
                    } else {
                      throw std::runtime_error("unsupported input type: " + type->str());
                    }
                  }
                  # node嵌入graph
                  graph->insertNode(node);
            
                  jit::tracer::setTracingState(nullptr);
                }

        可以看到,在具體運(yùn)算發(fā)生時(shí),會(huì)使用 getTracingState() 得到 forward 開始去創(chuàng)建的 state,然后看到根據(jù)op.schema().name() 得到計(jì)算類型(比如相加),根據(jù)計(jì)算類型通過 createNone 方法創(chuàng)建一個(gè)計(jì)算節(jié)點(diǎn),然后創(chuàng)建計(jì)算輸入,最后把計(jì)算node insert 到 graph 中,完成一次對(duì)計(jì)算的記錄。

        script

        因?yàn)?script 得到 IR 的方式是解析源碼,因此對(duì)于不同的代碼形式會(huì)略有不同(函數(shù),class,nn.Module的instance):1 Python 函數(shù) 簡(jiǎn)化后 code

            def script(obj, optimize=None, _frames_up=0, _rcb=None):
                # fucntion 分支
                if hasattr(obj, "__script_if_tracing_wrapper"):
                    obj = obj.__original_fn
                    _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
            
                # 檢查重載
                _check_directly_compile_overloaded(obj)
                # 是否之前被script過了
                maybe_already_compiled_fn = _try_get_jit_cached_function(obj)
                if maybe_already_compiled_fn:
                    return maybe_already_compiled_fn
                # 得到ast語(yǔ)法樹
                ast = get_jit_def(obj, obj.__name__)
                if _rcb is None:
                    _rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
                #c++ 入口,根據(jù)ast得到ir
                fn = torch._C._jit_script_compile(
                    qualified_name, ast, _rcb, get_default_args(obj)
                )
                # Forward docstrings
                fn.__doc__ = obj.__doc__
                # cache起來
                _set_jit_function_cache(obj, fn)
                return fn

        我們看下get_jit_def是如何得到 jit 規(guī)定的 ast 語(yǔ)法樹的

        僅保留邏輯代碼,細(xì)節(jié)刪掉

            def get_jit_def(fn, def_name, self_name=None):

                # 得到源代碼的一些信息
                sourcelines, file_lineno, filename = get_source_lines_and_file(fn, torch._C.ErrorReport.call_stack())
                sourcelines = normalize_source_lines(sourcelines)
                source =  dedent_src ''.join(sourcelines)
                # dedent_src 為包含了要script函數(shù)的字符串
                dedent_src = dedent(source)
                # 調(diào)用python ast包將字符串解析為Python的ast
                py_ast = ast.parse(dedent_src)
            
                # 得到python類型注釋
                type_line = torch.jit.annotations.get_type_line(source)
                #ctx中包含了函數(shù)所有原信息
                ctx = SourceContext(source, filename, file_lineno, leading_whitespace_len, True)
                fn_def = py_ast.body[0]
            
                # build_def將python 的ast 轉(zhuǎn)化為torchjit 使用的ast格式
                return build_def(ctx, fn_def, type_line, def_name, self_name=self_name)

        用一個(gè)簡(jiǎn)單的例子給大家解釋下 py_ast.body[0] 是什么

            import ast
            ... func_def= \
            ... """def test(a):
            ...     a = a + 2
            ...     return a + 1"""

            ... results = ast.parse(func_def)

        Python 解析出的 AST

        可見,ast.body 是一個(gè) list,其長(zhǎng)度等于解析的 string 中包含的函數(shù)的個(gè)數(shù),我們看第一個(gè)元素,其中 value 是一個(gè)

        Binop具體為一個(gè)Add,left 是Name類型,id為``a,right是Num,也就是2,這個(gè)Binop即解析的a = a + 2`。

        因?yàn)槲覀?get_source_lines_and_file 返回的一定是一個(gè) single top-level function, 因此我們直接取用第 0個(gè)元素,即 py_ast.body[0] 就可以了。

        接下來看build_def是如何將 Python 的 ast 轉(zhuǎn)化為自己需要的 ast 的。

        進(jìn)入buid_def

            def build_def(ctx, py_def, type_line, def_name, self_name=None):
                ....
                return Def(Ident(r, def_name),
                           decl,
                           build_stmts(ctx, body))

        因?yàn)?code style="font-size: 14px;overflow-wrap: break-word;padding: 2px 4px;border-radius: 4px;margin-right: 2px;margin-left: 2px;color: rgb(30, 107, 184);background-color: rgba(27, 31, 35, 0.05);font-family: "Operator Mono", Consolas, Monaco, Menlo, monospace;word-break: break-all;">ctx 包含 source code 所有信息, body 是 Python ast 解析結(jié)果,那么build_stmts中應(yīng)該包含我們想要的答案。

        我們用例子中a+2為例看會(huì)怎么轉(zhuǎn)換,這部分可見frontend.py

        關(guān)于StmtBuilder

            
            from torch._C._jit_tree_views import (
                ClassDef, Ident, Stmt, Decl, Def, Var,
                EmptyTypeAnnotation, Param, ExprStmt, Assign,
                Delete, Return, Raise, Assert, AugAssign, While,
                For, If, Pass, Break, Continue, Apply, Dots, Select,
                TrueLiteral, FalseLiteral, NoneLiteral, Starred,
                ListLiteral, TupleLiteral, DictLiteral, Const,
                StringLiteral, ListComp, Attribute, BinOp, UnaryOp,
                SliceExpr, Subscript, TernaryIf, With, WithItem, Property,
                DictComp,
            )
            # jit中定義的ast基本結(jié)構(gòu)
            
            def build_stmts(ctx, stmts):
                #發(fā)現(xiàn)其調(diào)用了`build_stmt`
                stmts = [build_stmt(ctx, s) for s in stmts]
                return list(filter(None, stmts))
            
            #`build_stmt` 是一個(gè)StmtBuilder()的instance
            build_stmt = StmtBuilder()
            build_expr = ExprBuilder()
            
            class Builder(object):
                def __call__(self, ctx, node):
                    # 可見會(huì)根據(jù)解析出的ast的類型返回相應(yīng)的build方法,從截圖可以看到`a+2`是一個(gè)`Assign`類型
                    # 因此會(huì)調(diào)用build_Assign
                    method = getattr(self, 'build_' + node.__class__.__name__, None)
                    if method is None:
                        raise UnsupportedNodeError(ctx, node)
                    return method(ctx, node)
            
            class StmtBuilder(Builder):
                @staticmethod
                def build_Assign(ctx, stmt):
                    # 截圖可以看到stmt.value是一個(gè)Binop
                    # build_expr是ExprBuilder的INSTANCE,其會(huì)調(diào)用`build_BinOp`
                    rhs = build_expr(ctx, stmt.value)
                    lhs = [build_expr(ctx, x) for x in stmt.targets]
                    return Assign(lhs, rhs)
            
                @staticmethod
                def build_Expr(ctx, stmt):
                    # Binop
                    value = stmt.value
                    if value.__class__.__name__ == 'Str':
                        # If a statement is a string literal expression,
                        # then it is a docstring. Just ignore it.
                        return None
                    else:
                        return ExprStmt(build_expr(ctx, value))
            
             class ExprBuilder(Builder):
                    binop_map = {
                    ast.Add: '+',
                    ast.Sub: '-',
                    ast.Mult: '*',
                    ast.Div: '/',
                    ast.Pow: '**',
                    ast.Mod: '%',
                    ast.FloorDiv: '//',
                    ast.BitAnd: '&',
                    ast.BitXor: '^',
                    ast.BitOr: '|',
                    ast.LShift: '<<',
                    ast.RShift: '>>',
                }
                    @staticmethod
                def build_BinOp(ctx, expr):
                    #expr.left是個(gè)`Name`調(diào)用build_Name
                    lhs = build_expr(ctx, expr.left)
                    rhs = build_expr(ctx, expr.right)
                    op = type(expr.op)
                    # 轉(zhuǎn)化為約定的代表運(yùn)算類型的string 符號(hào)
                    op_token = ExprBuilder.binop_map.get(op)
                    return BinOp(op_token, lhs, rhs)

        最終轉(zhuǎn)化為的格式,類似于S-expression.

            (def
              (ident test)
              (decl
                (list
                  (param
                    (ident a)
                    (option)
                    (option)
                    (False))
        )

                (option))

              (list
                (assign
                  (list (variable (ident a)))
                  (option
                    (+
                      (variable (ident a))
                      (const 2))
        )

                  (option))

                (return
                  (+
                    (variable (ident a))
                    (const 1))
        )
        )
        )

        好的,我們已經(jīng)得到得到j(luò)it約定的 AST 樹了,接下來我們要進(jìn)入 torch._C._jit_script_compile查看如何將這樣的 ast 樹轉(zhuǎn)化為 IR.

        C++ 入口為 script_compile_function

            static StrongFunctionPtr script_compile_function(
                const c10::QualifiedName& name,
                const Def& def,
                const FunctionDefaults& defaults,
                const ResolutionCallback& rcb)
         
        {
               #  def 中包含ast,跟著它就能找到答案
              auto cu = get_python_cu();
              #看來是get_python_cu這個(gè)類中的define函數(shù)完成的
              auto defined_functions = cu->define(
                  QualifiedName(name.prefix()),
                  /*properties=*/{},
                  /*propResolvers=*/{},
                  {def},
                  {pythonResolver(rcb)},
                  nullptr,
                  true);
              TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
              auto& defined = defined_functions[0];
              defined->setSchema(getSchemaWithNameAndDefaults(
                  def.range(), defined->getSchema(), def.name().name(), defaults));
              StrongFunctionPtr ret(std::move(cu), defined);
              didFinishEmitFunction(ret);
              return ret;
            }
            # 發(fā)現(xiàn)只是wapper了下CompilationUnit
            inline std::shared_ptr<CompilationUnit> get_python_cu() 
        {
              return py::module::import("torch.jit._state")
                  .attr("_python_cu")
                  .cast<std::shared_ptr<CompilationUnit>>();
            }
            
            #關(guān)于compilation_unit
            #/torch/csrc/jit/api/compilation_unit.h
             // for historic reasons, these are defined in ir_emitter.cpp
             // Returns the list of Functions just defined.
              std::vector<Function*> define(
                  const c10::optional<c10::QualifiedName>& prefix,
                  const std::vector<Property>& properties,
                  const std::vector<ResolverPtr>& propResolvers,
                  const std::vector<Def>& definitions,
                  const std::vector<ResolverPtr>&
                      defResolvers, /* determines how we handle free
                                 variables in each definition*/

                  // if non-null, the first argument to each def, is bound to this value
                  const Self* self,
                  // see [name mangling]
                  bool shouldMangle = false)
        ;
            #實(shí)現(xiàn)在torch/csrc/jit/frontend/ir_emitter.cpp
            std::unique_ptr<Function> CompilationUnit::define(
                const c10::optional<QualifiedName>& prefix,
                const Def& def,
                const ResolverPtr& resolver,
                const Self* self,
                const std::unordered_map<std::string, Function*>& function_table,
                bool shouldMangle)
         const 
        {
            
              auto _resolver = resolver;
              .....
              auto creator = [def, _resolver, self](Function& method) {
                ....
                ##核心代碼to_ir
                to_ir(def, _resolver, self, method);
              };
            
              auto fn = torch::make_unique<GraphFunction>(
                  std::move(name), std::make_shared<Graph>(), creator);
              return fn;
            }

        我們跟隨 def,找到了一個(gè)轉(zhuǎn)化為 IR 的關(guān)鍵的structto_ir,其輸入中有 def,也就是 ast,_resolver 是 Python 中傳過來的解析名字的函數(shù),我們可以在內(nèi)部找到關(guān)鍵部分

            to_ir(
                  const Def& def,
                  ResolverPtr resolver_,
                  const Self* self,
                  Function& method) // method being constructed
                  : method(method),
                    graph(method.graph()),
                    resolver(std::move(resolver_)),
                    typeParser_(resolver),
                    environment_stack(nullptr) {
                AT_ASSERT(resolver);
                pushFrame(graph->block(), /*starts_def=*/true);
            
                #emitDef 中會(huì)調(diào)用emitStatements
                method.setSchema(emitDef(def, self, graph->block()));
                ConvertToSSA(graph);
                CanonicalizeModifiedLoops(graph);
                NormalizeOps(graph);
                runCleanupPasses(graph);
              }
            private:
             #在to_ir 的private中我們可以看到Graph Function這些我們之前介紹的IR的組成部分
              Function& method;
              std::shared_ptr<Graph> graph;
              ResolverPtr resolver;
              std::unordered_map<int64_t, Value*> integral_constants;  
            
             #emitDef 中會(huì)調(diào)用emitStatements
             FunctionSchema emitDef(const Def& def, const Self* self, Block* block) 
        {
                ......
                // body
                auto stmts_list = def.statements();
                emitStatements(stmts_list.begin(), stmts_list.end());
                 ........
              }
             void emitStatements(
                  List<Stmt>::const_iterator begin,
                  List<Stmt>::const_iterator end)
         
        {
                for (; begin != end; ++begin) {
                  auto stmt = *begin;
                  ErrorReport::CallStack::update_pending_range(stmt.range());
                  switch (stmt.kind()) {
                    case TK_IF:
                      emitIf(If(stmt));
                      break;
                    case TK_WHILE:
                      emitWhile(While(stmt));
                      break;
                    case TK_FOR:
                      emitFor(For(stmt));
                      break;
                    case TK_ASSIGN:
                      emitAssignment(Assign(stmt));
                   .................
                      break;
                    default:
                      throw ErrorReport(stmt)
                          << "Unrecognized statement kind " << kindToString(stmt.kind());
                  }
                  // Found an exit statement in this block. The remaining statements aren't
                  // reachable so we don't emit them.
                  if (exit_blocks.count(environment_stack->block()))
                    return;
                }
              }


        我們可以看到根據(jù)stmt.kind(),會(huì)進(jìn)入而各種emit里面,其中一定可以找到
        graph->insertNode(graph->create(.....));
        類似的操作,對(duì)應(yīng)我們建立IR graph

        以上是我們以一個(gè) function 為例子,接下來我們以 script 一個(gè) module為例,其有一些獨(dú)有的挑戰(zhàn),因?yàn)橛幸恍┳兞康闹复?,是需要初始化后才知道的,同時(shí),我們希望 script 完的 module 對(duì)外還能保持一樣的接口,即可以正常訪問原有 module 的屬性,那么應(yīng)該怎么做呢?

        1. 在 module 原有的 init 結(jié)束后隨即開始完整的 script forward 函數(shù),替換涉及到的所有函數(shù)為 script 后的函數(shù)
        2. 如何正常訪問原有的屬性

        如何在一個(gè)類的 init 函數(shù)后面綁定行為呢,我們想到 metaclass,torch.jit 實(shí)現(xiàn)了 ScriptMeta這個(gè) metaclass。

        class MyModule(torch.jit.ScriptModule):
            @torch.jit.script_method
            def f(self.x):
                return x * x
            @torch.jit.script_method
            def forward(self, x):
                 return x + self.f(x)

        關(guān)于script_method

            def script_method(fn):
            
                _rcb = _jit_internal.createResolutionCallbackFromFrame(frames_up=2)
                ast = get_jit_def(fn, fn.__name__, self_name="ScriptModule")
                #暫時(shí)沒有script,只是返回包含ast的nametuple
                return ScriptMethodStub(_rcb, ast, fn)
            
                ScriptMethodStub = collections.namedtuple('ScriptMethodStub', ('resolution_callback''def_''original_method'))

        1. 移除所有script_method屬性被(@script_method修飾的方法),確保訪問到的是script function
        2. 修改module的_init_,確保module的self.param或者self.module初始化后立即編譯所有的script_method,從而生成的instance的forward已經(jīng)被替換

            class ScriptMeta(type):
                def __init__(cls, name, bases, attrs):  # noqa: B902
                    # cls ScriptMeta的instance,是一個(gè)類如ScriptModule
                    cls._methods: Dict[str, Any] = {}
                    cls._constants_set = set(getattr(cls, "__constants__", ()))
                    for base in reversed(bases):
                        # 還記得嗎t(yī)race的module也是有一個(gè)_methods的屬性
                        for k, v in getattr(base, "_methods", {}).items():
                            cls._methods[k] = v
                        base_constants = getattr(base, "_constants_set", set())
                        cls._constants_set = cls._constants_set.union(base_constants)
            
                    # 找到現(xiàn)在所有被@script_method修飾的方法,放到_method,并刪除原有attr
                    # init后之后統(tǒng)一script
                    for k, v in sorted(attrs.items()):
                        if isinstance(v, ScriptMethodStub):
                            delattr(cls, k)
                            cls._methods[v.original_method.__name__] = v


            
                    original_init = getattr(cls, "__init__"lambda self: None)
            
                    # 此處實(shí)現(xiàn)了init結(jié)束后,調(diào)用create_script_module進(jìn)行script
                    @functools.wraps(original_init)
                    def init_then_script(self, *args, **kwargs):
                        # 此處的self為instance
                        num_methods = len(cls._methods)
                        original_init(self, *args, **kwargs)
                        added_methods_in_init = len(cls._methods) > num_methods
            
                        if type(self) == cls:
                            # 選取需要script的method
                            def make_stubs(module):
                                cls = type(module)
                                if hasattr(cls, "_methods"):
                                    return [v for k, v in sorted(cls._methods.items())]
                                else:
                                    # infer_methods_to_compile 是一個(gè)選取要script函數(shù)的函數(shù)
                                    return infer_methods_to_compile(module)
                            # 講所有script_method一塊編譯為_actual_script_module屬性
            
                            self.__dict__[
                                "_actual_script_module"
                            ] = torch.jit._recursive.create_script_module(self, make_stubs, share_types=not added_methods_in_init)
            
                            # Delete the Python attributes that now shadow the ScriptModule
                            # ones, so that __getattr__ and __setattr__ will properly find
                            # the scripted versions.
                            concrete_type = self._actual_script_module._concrete_type
                            for name in concrete_type.get_attributes():
                                delattr(self, name)
                            for name, _ in concrete_type.get_modules():
                                delattr(self, name)
                            for name in ("_parameters""_buffers""_modules"):
                                delattr(self, name)
            
                    cls.__init__ = init_then_script  # type: ignore
            
                    return super(ScriptMeta, cls).__init__(name, bases, attrs)
            
              class _CachedForward(object):
                    def __get__(self, obj, cls):
                        return self.__getattr__("forward")  # type: ignore
            
               class ScriptModule(with_metaclass(ScriptMeta, Module)):  # type: ignore
            
                    def __init__(self):
                        super(ScriptModule, self).__init__()
            
                    forward = _CachedForward()
                    # 想訪問module的attr,返回_actual_script_module的attr
                    def __getattr__(self, attr):
                        if "_actual_script_module" not in self.__dict__:
                            return super(ScriptModule, self).__getattr__(attr)
                        return getattr(self._actual_script_module, attr)
            
                    def __setattr__(self, attr, value):
                        if "_actual_script_module" not in self.__dict__:
                            # Unwrap torch.jit.Attribute into a regular setattr + recording
                            # the provided type in __annotations__.
                            #
                            # This ensures that if we use the attr again in `__init__`, it
                            # will look like the actual value, not an instance of Attribute.
                            if isinstance(value, Attribute):
                                if "__annotations__" not in self.__class__.__dict__:
                                    self.__class__.__annotations__ = {}
                                self.__annotations__[attr] = value.type
                                value = value.value
                            return super(ScriptModule, self).__setattr__(attr, value)
            
                        setattr(self._actual_script_module, attr, value)

        關(guān)于 create_script_module 函數(shù)會(huì) script method 然后返回一個(gè)RecursiveScriptModule,但是其邏輯較為復(fù)雜,在此不再展開。

        關(guān)于 getattribute vs getattr

        當(dāng)訪問某個(gè)實(shí)例屬性時(shí),getattribute 會(huì)被無條件調(diào)用,當(dāng)這個(gè)屬性不存在,則會(huì)調(diào)用 getattr,如未實(shí)現(xiàn)自己的 getattr 方法,會(huì)拋出AttributeError 提示找不到這個(gè)屬性,如果自定義了自己 getattr 方法的話方法會(huì)在這種找不到屬性的情況下被調(diào)用。

        4 IR優(yōu)化的簡(jiǎn)單介紹

        jit 一般涉及如下優(yōu)化: loop unrolling peephole optimization constant propagation DCE fusion inlining... 我們看如下例子:

            def test(x):
                # Dead code Elimination
                for i in range(1000):
                    y = x + 1
                for i in range(100):
                    #peephole optimization
                    x = x.t()
                    x = x.t()
                return x.sum()
            
            opt_test = torch.jit.script(test)
            s = time()
            inputs = torch.ones(4,4).cuda()
            s = time()
            for i in range(10000):
                test(inputs)
            print(time()-s)
            # 95s
            s = time()
            for i in range(10000):
                opt_test(inputs)
            print(time()-s)
            # 0.13s
            print(opt_test.graph)
            print(opt_test.graph_for(inputs))
            95.13823795318604
            0.13010907173156738
            graph(%x.1 : Tensor):
              %22 : None = prim::Constant()
              %13 : bool = prim::Constant[value=1]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
              %10 : int = prim::Constant[value=100]() # /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:19
              %x : Tensor = prim::Loop(%10, %13, %x.1# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:10:4
                block0(%i : int, %x.10 : Tensor):
                  %x.4 : Tensor = aten::t(%x.10# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:11:12
                  %x.7 : Tensor = aten::t(%x.4# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:12:12
                  -> (%13, %x.7)
              %23 : Tensor = aten::sum(%x, %22# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
              return (%23)
            
            graph(%x.1 : Tensor):
              %1 : None = prim::Constant()
              %2 : Tensor = aten::sum(%x.1, %1# /home/SENSETIME/zhangshilong/PycharmProjects/pythonProject/opt.py:13:11
              return (%2)

        關(guān)于 IR 計(jì)算圖優(yōu)化

        IR 的 Method 中內(nèi)置 GraphExecutor object,創(chuàng)建于第一次執(zhí)行的時(shí)候,負(fù)責(zé)優(yōu)化。
        文件 pytorch-master/torch/csrc/jit/api/method.h scritp_method 的 C++ 原型里

            GraphExecutor& get_executor() {
                return function_->get_executor();
              }

        GraphExecutor 的定義在/torch/csrc/jit/runtime/graph_executor.cpp,可見其由 graph 產(chǎn)生,定義了 run 方法執(zhí)行

            GraphExecutor::GraphExecutor(
                const std::shared_ptr<Graph>& graph,
                std::string function_name)
                : pImpl(
                      IsNewExecutorEnabled()
                          ? dynamic_cast<GraphExecutorImplBase*>(
                                new ProfilingGraphExecutorImpl(
                                    graph,
                                    std::move(function_name)))
                          : dynamic_cast<GraphExecutorImplBase*>(
                                new GraphExecutorImpl(graph, std::move(function_name)))) {}
            std::shared_ptr<Graph> GraphExecutor::graph() const {
              return pImpl->graph;
            }
            const ExecutionPlan& GraphExecutor::getPlanFor(
                Stack& inputs,
                size_t remaining_bailout_depth)
         
        {
              return pImpl->getPlanFor(inputs, remaining_bailout_depth);
            }
            
             std::shared_ptr<GraphExecutorImplBase> pImpl;
            .....

        關(guān)于 GraphExecutorImplBase,/torch/csrc/jit/runtime/graph_executor.cpp


            const ExecutionPlan& getOrCompile(const Stack& stack) 
        {
                  .....
                  auto plan = compileSpec(spec);
            
                }
              }
            # compileSpec 會(huì)返回一個(gè)plan
            ExecutionPlan compileSpec(const ArgumentSpec& spec) 
        {
                auto opt_graph = graph->copy();
                GRAPH_DUMP("Optimizing the following function:", opt_graph);
                arg_spec_creator_.specializeTypes(*opt_graph, spec);
            
                // Phase 0. Inline functions, then clean up any artifacts that the inliner
                //          left in that may inhibit optimization
                 .....
                runRequiredPasses(opt_graph);
                GRAPH_DEBUG(
                    "After runRequiredPasses, before ConstantPropagation\n", *opt_graph);
            
                // Phase 2. Propagate detailed information about the spec through the
                //          graph (enabled more specializations in later passes).
                //          Shape propagation sometimes depends on certain arguments being
                //          constants, and constant propagation doesn't need shape
                //          information anyway, so it's better to run it first.
                ConstantPropagation(opt_graph);
                GRAPH_DEBUG(
                    "After ConstantPropagation, before PropagateInputShapes\n", *opt_graph);
                PropagateInputShapes(opt_graph);
                GRAPH_DEBUG(
                    "After PropagateInputShapes, before PropagateRequiresGrad\n",
                    *opt_graph);
                PropagateRequiresGrad(opt_graph);
                GRAPH_DEBUG(
                    "After PropagateRequiresGrad, before runOptimization\n", *opt_graph);
            
                // Phase 3. Run differentiable optimizations (i.e. simple graph rewrites
                //          that we can still execute using autograd).
                runOptimization(opt_graph);
                .....各種優(yōu)化
                return ExecutionPlan(opt_graph, function_name_);
              }

        這些優(yōu)化在 torch/csrc/jit/passes/ 文件夾

        torch/csrc/jit/passes/dead_code_elimination.cpp

        /torch/csrc/jit/passes/fuse_linear.cpp

        torch/csrc/jit/passes/remove_dropout.cpp

        torch/csrc/jit/passes/fold_conv_bn.cpp

        參考

        1. INTRODUCTION TO TORCHSCRIPT

        2. PyTorch 部署_TorchScript

        3.pytorch_wiki

        4. PyTorch-JIT-Source-Code-Read-Note

        5. Abstract_syntax_tree


        - The End -


        GiantPandaCV

        長(zhǎng)按二維碼關(guān)注我們

        本公眾號(hào)專注:

        1. 技術(shù)分享;

        2. 學(xué)術(shù)交流;

        3. 資料共享

        歡迎關(guān)注我們,一起成長(zhǎng)!

        瀏覽 60
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            一级AAAAAA毛片免费 | 欧美一本乱大交性XXXⅩ | 91尤物国产网红尤物色大师 | 我和工厂少妇的性系列 | 乱亲h女秽乱长久久久 | 牛牛操逼 | 少妇高潮喷水 | 天干夜天干天天天爽视频 | 成人视频在线观看高清无码18 | 黄色3级网站 |