1. mlc-llm 推理優(yōu)化和大語言模型搭建解析(文末送書)

        共 95254字,需瀏覽 191分鐘

         ·

        2023-09-27 02:52

        在這里插入圖片描述

        0x0. 前言

        本文解析一下mlc-llm(https://github.com/mlc-ai/mlc-llm)對大模型推理的流程以及使用的圖優(yōu)化,算子優(yōu)化策略。mlc-llm的模型部署流程可以查看官方文檔:https://mlc.ai/mlc-llm/docs/ ,也可以參考我前段時間寫的這篇MLC-LLM 部署RWKV World系列模型實戰(zhàn)(3B模型Mac M2解碼可達26tokens/s) 。

        此外,閱讀mlc-llm的代碼還需要理解一些TVM Unify的一些基礎(chǔ)概念,可以參考TVM 學習指南(個人版) ,Relax: TVM 的下一代圖層級 IR,新一代深度學習編譯技術(shù)變革和展望等等。從 https://github.com/BBuf/tvm_mlir_learn 這里可以查看更多相關(guān)博客和資料。

        MLC-LLM 部署RWKV World系列模型實戰(zhàn)(3B模型Mac M2解碼可達26tokens/s) 中提到要使用mlc-llm部署模型首先需要一個編譯過程,將原始的基于Realx搭建的模型比如RWKV和給定的device信息一起編譯為TVM中的runtime.Module(在linux上編譯的產(chǎn)物就是.so文件)提供mlc-llm的c++推理接口調(diào)用 。我們就從這里看起:

        由于mlc-llm上游更新很快,為了準確標定代碼位置我fork了一份2023年9月17號的mlc-llm代碼 :https://github.com/BBuf/mlc-llm-code-analysis ,本文的注釋以及指出的代碼位置均以這個fork倉庫為準。

        0x1. 編譯流程解析

        編譯的入口在:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/build.py 。

        這個腳本構(gòu)建了一個模型build的入口,可以通過傳入不同的參數(shù)來構(gòu)建不同配置的模型。參數(shù)解析和模型編譯都在 https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py 中實現(xiàn),模型編譯準備(mod_transform_before_build函數(shù))和編譯(build函數(shù))兩個階段。在模型編譯準備階段,包含準備需要優(yōu)化的算子,執(zhí)行一些基礎(chǔ)的圖變換,針對cuda做進一步優(yōu)化,做算子fuse等優(yōu)化,詳細的解釋清閱讀這里的注釋:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L378 。

        在這之后會執(zhí)行編譯過程:https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L378 。從這里我們可以看到,對于GPU來說使用的是默認的schedule模板,并沒有使用AutoTVM/Ansor等等調(diào)優(yōu)工具,這一點是很友好的,個人猜測也是因為Transformer架構(gòu)的模型是很固定的,然后優(yōu)化方法也比較統(tǒng)一。

        上面的編譯前準備和編譯都是針對IRModule來說的,那么這個IRModule是怎么來的呢?以及量化是在哪里做的?這兩個問題都是在 build_model_from_args 函數(shù): https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/core.py#L627 處理的,發(fā)生在 mod_transform_before_build 函數(shù)調(diào)用之前。以 RWKV 模型為例,通過這行 mod, param_manager, params, model_config = rwkv.get_model(args, config) 代碼完成了從原始的 huggingface 模型到初始的 IRModule 的轉(zhuǎn)換,在這個過程中也包含了量化。

        0x2. 模型搭建解析

        0x2.1 模型組件搭建

        首先在 https://github.com/BBuf/mlc-llm-code-analysis/blob/main/mlc_llm/relax_model/modules.py 這里基于Relax的內(nèi)部接口(relax.Expr,relax.testing.nn.Module,relax.op.xxx等等)定義了搭建LLM模型需要的一些組件比如 ModuleListLinearEmbedding,LayerNorm,RotaryEmbedding等等。這個地方我添加了一些解釋,請點上面的源碼鏈接查看。然后這個地方需要注意2個特殊的op,第一個是來自 https://github.com/mlc-ai/relax/blob/ceaf7b0156524d30537a3de5fa30764eaff4edb8/python/tvm/relax/op/index.py#L28 的:


        def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
            return _ffi_api.take(x, indices, axis)  # type: ignore

        這個函數(shù),實現(xiàn)了take的核心功能,與numpy和pytorch的take語義類似,都可以通過指定indices來從輸入張量中抽取值。主要調(diào)用了_ffi_api.take進行取值操作, 這個_ffi_api是relax底層實現(xiàn), take操作的實際計算會在這里進行。這個函數(shù)被用于Embedding組件的搭建中。

        另外nn.emit這個接口的作用是將一個relax.Expr表達式轉(zhuǎn)化為relax.Var變量,并保存該變量。

        最后我們注意到這里搭建的Relax模塊風格和PyTorch的模塊風格基本一致,也可以看出Relax前端是不斷靠近動態(tài)圖風格,追求更佳的易用性。

        0x2.2 模型搭建

        首先看一些準備工作:

        # @dataclass:這個裝飾器用于指示RWKVConfig類是一個數(shù)據(jù)類。用于存儲RWKVModel的配置信息。
        @dataclass
        class RWKVConfig:
            """The configuration class to store the configuration of a `RWKVModel`."""

            num_hidden_layers: int # 類中的一個屬性,用于存儲隱藏層的數(shù)量,類型為整數(shù)。
            vocab_size: int # 類中的一個屬性,用于存儲詞匯表的大小,類型為整數(shù)。
            hidden_size: int # 類中的一個屬性,用于存儲隱藏層的大小,類型為整數(shù)。
            intermediate_size: int # 類中的一個屬性,用于存儲中間層的大小,類型為整數(shù)。
            rescale_every: int = 0 # 類中的一個屬性,默認值為0,用于存儲重新縮放的頻率,類型為整數(shù)。
            layer_norm_epsilon: float = 1e-5 # 類中的一個屬性,默認值為1e-5,用于存儲層歸一化的epsilon值,類型為浮點數(shù)。
            max_sequence_length: int = 1024 # 類中的一個屬性,默認值為1024,用于存儲最大序列長度,類型為整數(shù)。
            dtype: str = "float32" # 類中的一個屬性,默認值為"float32",用于存儲數(shù)據(jù)類型,類型為字符串。

            def __init__(
                self,
                num_hidden_layers: int,
                vocab_size: int,
                hidden_size: int,
                intermediate_size: int,
                rescale_every: int = 0,
                layer_norm_epsilon: float = 1e-5,
                context_length: int = 1024,
                dtype: str = "float32",
                **kwargs,
            )
         -> None:

                self.num_hidden_layers = num_hidden_layers
                self.vocab_size = vocab_size
                self.hidden_size = hidden_size
                self.intermediate_size = intermediate_size
                self.rescale_every = rescale_every
                self.layer_norm_epsilon = layer_norm_epsilon
                self.max_sequence_length = context_length
                self.dtype = dtype
                self.kwargs = kwargs

        # 用來索引RWKV的Attention和FFN部分存儲的狀態(tài)或者叫Cache。
        # python代碼可以參考: https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L858-L867
        class State:
            ATT_X = 0
            ATT_A = 1
            ATT_B = 2
            ATT_P = 3
            FFN_X = 4

        這里的State是用來索引RWKV的Attention和FFN部分存儲的狀態(tài)或者叫Cache,每一個Layer有5個不同的State,并且每個State的shape都是[1, hidden_size],這里的1代表的應(yīng)該是batch緯度。

        # 義了一個名為_load_state的函數(shù),它接受一個名為state的參數(shù),類型為Expr,一個名為hidden_size的參數(shù),類型為整數(shù),
        # 一個名為dtype的參數(shù),類型為字符串。函數(shù)的返回類型為Expr。
        def _load_state(state: Expr, hidden_size: int, dtype: str) -> Expr:
            # Reuse `attention_kv_cache_view`
            # 將外部函數(shù)vm.builtin.attention_kv_cache_view賦值給變量f_load_cache。relax.extern是一個外部函數(shù)調(diào)用的語法,
            # 它指示編譯器在編譯時將該函數(shù)調(diào)用轉(zhuǎn)換為相應(yīng)的外部函數(shù)調(diào)用。
            f_load_cache = relax.extern("vm.builtin.attention_kv_cache_view")
            # 使用nn.emit方法生成一個表達式對象,該表達式表示對外部函數(shù)f_load_cache的調(diào)用。
            # 調(diào)用的參數(shù)是一個列表,包含state和R.shape([1, hidden_size]),以及sinfo_args參數(shù)指定的一個R.Tensor對象。
            cache = nn.emit(
                relax.Call(
                    f_load_cache,
                    [state, R.shape([1, hidden_size])],
                    sinfo_args=[R.Tensor((1, hidden_size), dtype)],
                )
            )
            return cache

        # 定義了一個名為_store_state的函數(shù),它接受一個名為state的參數(shù),類型為Expr,一個名為value的參數(shù),類型為Expr。
        def _store_state(state: Expr, value: Expr):
            # Reuse `attention_kv_cache_update`
            # 將外部函數(shù)vm.builtin.attention_kv_cache_update賦值給變量f_store_cache。
            # relax.extern是一個外部函數(shù)調(diào)用的語法,它指示編譯器在編譯時將該函數(shù)調(diào)用轉(zhuǎn)換為相應(yīng)的外部函數(shù)調(diào)用。
            f_store_cache = relax.extern("vm.builtin.attention_kv_cache_update")

            # 使用nn.emit方法生成一個表達式對象,該表達式表示對外部函數(shù)f_store_cache的調(diào)用。
            # 調(diào)用的參數(shù)是一個列表,包含state和value,以及sinfo_args參數(shù)指定的一個R.Object()對象。
            return nn.emit(
                relax.Call(
                    f_store_cache,
                    [state, value],
                    sinfo_args=[R.Object()],
                )
            )

        這兩個函數(shù)用來加載和存儲RWKV模型的State。接下來看一下對應(yīng) https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L741 這里的torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) 的Relax實現(xiàn),為了方便對照先貼一下原始的wkv forward cuda kernel:

        template <typename F>
        __global__ void kernel_wkv_forward(const int B, const int T, const int C,
                                       const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
                                       F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
            const int idx = blockIdx.x * blockDim.x + threadIdx.x;
            const int _b = idx / C;
            const int _c = idx % C;
            const int _offset = _b * T * C + _c;
            const int _state_offset = _b * C + _c;

            float u = _u[_c];
            float w = _w[_c];
            const F *__restrict__ const k = _k + _offset;
            const F *__restrict__ const v = _v + _offset;
            F *__restrict__ const y = _y + _offset;

            float aa = _aa[_state_offset];
            float bb = _bb[_state_offset];
            float pp = _pp[_state_offset];
            for (int i = 0; i < T; i++) {
                const int ii = i * C;
                const float kk = float(k[ii]);
                const float vv = float(v[ii]);
                float ww = u + kk;
                float p = max(pp, ww);
                float e1 = exp(pp - p);
                float e2 = exp(ww - p);
                y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
                ww = w + pp;
                p = max(ww, kk);
                e1 = exp(ww - p);
                e2 = exp(kk - p);
                aa = e1 * aa + e2 * vv;
                bb = e1 * bb + e2;
                pp = p;
            }
            _aa[_state_offset] = aa;
            _bb[_state_offset] = bb;
            _pp[_state_offset] = pp;
        }

        template <typename F>
        void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
            dim3 threadsPerBlock( min(C, 32) );
            assert(B * C % threadsPerBlock.x == 0);
            dim3 numBlocks(B * C / threadsPerBlock.x);
            kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
        }

        這個cuda kernel里面,B表示batch_size,在mlc-llm的實現(xiàn)默認為1。然后T表示序列長度,C表示隱藏層緯度。然后我們就可以對應(yīng)來看mlc-llm的wkv實現(xiàn)了。

        # 定義了一個名為create_wkv_func的函數(shù),它接受一個名為hidden_size的參數(shù),
        # 類型為整數(shù),一個名為dtype的參數(shù),類型為字符串,一個名為out_dtype的參數(shù),類型為字符串。
        def create_wkv_func(hidden_size: int, dtype: str, out_dtype: str):
            @T.prim_func
            def wkv_func(
                k: T.handle,
                v: T.handle,
                time_decay: T.handle,
                time_first: T.handle,
                saved_a: T.handle,
                saved_b: T.handle,
                saved_p: T.handle,
                wkv: T.handle,
                out_a: T.handle,
                out_b: T.handle,
                out_p: T.handle,
            )
        :

                # 設(shè)置TIR函數(shù)的屬性。這里設(shè)置了三個屬性,包括op_pattern、tir.noalias和tir.is_scheduled。
                T.func_attr({"op_pattern"8"tir.noalias"True"tir.is_scheduled"1})
                # 聲明一個名為context_length的變量,類型為T.int64(),用于存儲上下文長度。
                context_length = T.int64()
                # 創(chuàng)建一個名為K的匹配緩沖區(qū),通過T.match_buffer方法匹配參數(shù)k的形狀和數(shù)據(jù)類型。
                # K的形狀在原始的ChatRWKV中為B,T,C,只不過這里B=1
                # 這里的k就是上面cuda kernel的_k
                K = T.match_buffer(k, (context_length, hidden_size), dtype=dtype)
                # 創(chuàng)建一個名為V的匹配緩沖區(qū),通過T.match_buffer方法匹配參數(shù)v的形狀和數(shù)據(jù)類型。
                # 這里的v就是上面cuda kernel的_v
                V = T.match_buffer(v, (context_length, hidden_size), dtype=dtype)
                # 創(chuàng)建一個名為TimeDecay的匹配緩沖區(qū),通過T.match_buffer方法匹配參數(shù)time_decay的形狀和數(shù)據(jù)類型。
                # 這里的TimeDecay就是上面的w
                TimeDecay = T.match_buffer(time_decay, (hidden_size,), dtype=dtype)
                # 創(chuàng)建一個名為TimeFirst的匹配緩沖區(qū),通過T.match_buffer方法匹配參數(shù)time_first的形狀和數(shù)據(jù)類型。
                # 這里的TimeFirst對應(yīng)上面的u
                TimeFirst = T.match_buffer(time_first, (hidden_size,), dtype=dtype)
                # 對應(yīng)kernel里面的_aa的上一個token的狀態(tài)
                SavedA = T.match_buffer(saved_a, (1, hidden_size), dtype=dtype)
                # 對應(yīng)kernel里面的_bb的上一個token的狀態(tài)
                SavedB = T.match_buffer(saved_b, (1, hidden_size), dtype=dtype)
                # 對應(yīng)kernel里面的_pp的上一個token的狀態(tài)
                SavedP = T.match_buffer(saved_p, (1, hidden_size), dtype=dtype)
                # 對應(yīng)_aa的當前token狀態(tài)
                OutA = T.match_buffer(out_a, (1, hidden_size), dtype=dtype)
                # 對應(yīng)_bb的當前token狀態(tài)
                OutB = T.match_buffer(out_b, (1, hidden_size), dtype=dtype)
                # 對應(yīng)_pp的當前token狀態(tài)
                OutP = T.match_buffer(out_p, (1, hidden_size), dtype=dtype)

                # 對應(yīng)kernel里面的p
                P = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local")
                # 對應(yīng)kernel里面的e1
                E1 = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local")
                # 對應(yīng)kernel里面的e2
                E2 = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local")
                # 對應(yīng)kernel里面的aa
                A_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local")
                # 對應(yīng)kernel里面的bb
                B_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local")
                # 對應(yīng)kernel里面的cc
                P_local = T.alloc_buffer((hidden_size,), dtype=dtype, scope="local")

                # 迭代hidden_size // 32次,使用T.thread_binding方法進行線程綁定,其中hidden_size // 32是塊索引的范圍。
                # 這里的線程塊劃分和rwkv kernel里面保持一致:即每個block 32個線程,一共((B=1)*C)/32個blcok
                for bx in T.thread_binding(hidden_size // 32, thread="blockIdx.x"):
                    # 迭代32次,使用T.thread_binding方法進行線程綁定,其中32是線程索引的范圍。
                    for tx in T.thread_binding(32, thread="threadIdx.x"):
                        # 創(chuàng)建一個名為"init"的塊,用于初始化局部變量。
                        with T.block("init"):
                            # 對應(yīng) const int _state_offset = _b * C + _c;
                            vi = T.axis.S(hidden_size, bx * 32 + tx)
                            # 對應(yīng) float aa = _aa[_state_offset];
                            A_local[vi] = SavedA[0, vi]
                            # 對應(yīng) float bb = _bb[_state_offset];
                            B_local[vi] = SavedB[0, vi]
                            # 對應(yīng) float pp = _pp[_state_offset];
                            P_local[vi] = SavedP[0, vi]
                        for j in range(context_length): # 對應(yīng) for (int i = 0; i < T; i++)
                            with T.block("main"):
                                # 對應(yīng) const int _state_offset = _b * C + _c;
                                vi = T.axis.S(hidden_size, bx * 32 + tx)
                                # vj 對應(yīng) _b * T; [vj, vi] = _b * T * C + _b * C + _c
                                # _b * T * C + _c = _offset
                                vj = T.axis.opaque(context_length, j)
                                # 對應(yīng) float p = max(pp, ww); float ww = u + kk; 
                                # const float kk = float(k[ii]); const int ii = i * C;
                                # const F *__restrict__ const k = _k + _offset;
                                P[vi] = T.max(P_local[vi], K[vj, vi] + TimeFirst[vi])
                                # 對應(yīng) float e1 = exp(pp - p);
                                E1[vi] = T.exp(P_local[vi] - P[vi])
                                # 對應(yīng) float e2 = exp(ww - p);
                                E2[vi] = T.exp(K[vj, vi] + TimeFirst[vi] - P[vi])

                                P[vi] = T.max(P_local[vi] + TimeDecay[vi], K[vj, vi])
                                E1[vi] = T.exp(P_local[vi] + TimeDecay[vi] - P[vi])
                                E2[vi] = T.exp(K[vj, vi] - P[vi])
                                A_local[vi] = E1[vi] * A_local[vi] + E2[vi] * V[vj, vi]
                                B_local[vi] = E1[vi] * B_local[vi] + E2[vi]
                                P_local[vi] = P[vi]

                        with T.block("write_back"):
                            vi = T.axis.S(hidden_size, bx * 32 + tx) # 對應(yīng) 
                            OutA[0, vi] = A_local[vi] # 對應(yīng) _aa[_state_offset] = aa;
                            OutB[0, vi] = B_local[vi] # 對應(yīng) _bb[_state_offset] = bb;
                            OutP[0, vi] = P_local[vi] # 對應(yīng) _pp[_state_offset] = pp;

            return wkv_func

        我們可以看到mlc-llm里面的wkv forward實現(xiàn)基本就是用基于Relax的api將cuda函數(shù)翻譯成了TIR。注釋里面給了一些下標的推導以及每一行Relax的代碼是如何對應(yīng)到原始的cuda kernel。

        # 定義了一個名為_te_concat_saved_x的函數(shù),它接受兩個參數(shù)saved_x和x,都是te.Tensor類型的張量。
        # 使用TVM的te.compute函數(shù)計算一個新的張量,該張量的形狀與x相同,元素根據(jù)條件判斷進行選擇。如果i等于0,
        # 則選擇saved_x[0, j]作為元素值,否則選擇x[i - 1, j]作為元素值。其中i和j是迭代變量。
        def _te_concat_saved_x(saved_x: te.Tensor, x: te.Tensor):
            return te.compute(
                x.shape,
                lambda i, j: tir.if_then_else(i == 0, saved_x[0, j], x[i - 1, j]),
            )

        # 定義了一個名為_te_get_last_x的函數(shù),它接受一個參數(shù)x,是一個te.Tensor類型的張量。
        # a. seq_len, hidden_size = x.shape:獲取x張量的形狀,其中seq_len表示序列長度,hidden_size表示隱藏大小。
        # b. return te.compute(...):使用TVM的te.compute函數(shù)計算一個新的張量,該張量的形狀為(1, hidden_size),
        # 元素值為x[seq_len - 1, j],其中j是迭代變量。
        def _te_get_last_x(x: te.Tensor):
            seq_len, hidden_size = x.shape
            return te.compute((1, hidden_size), lambda _, j: x[seq_len - 1, j])

        這兩個函數(shù)應(yīng)該對應(yīng)了 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L455 這里代碼里面的sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))xx[-1, :]

        @MyFunction
            def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry):
                xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
                sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
                kx = xx * k_mix + sx * (1 - k_mix)
                rx = xx * r_mix + sx * (1 - r_mix)

                r = torch.sigmoid(gemm(rx, rw))
                vx = torch.square(torch.relu(gemm(kx, kw)))
                out = r * gemm(vx, vw)
                return x + out, xx[-1,:]

        接著對Embedding函數(shù)進行解析:

        # 定義了一個名為RWKV_Embedding的PyTorch模塊。
        class RWKV_Embedding(nn.Module):
            # 定義了RWKV_Embedding類的構(gòu)造函數(shù),接受三個參數(shù)num_embeddings、embedding_dim和dtype。
            def __init__(self, num_embeddings, embedding_dim, dtype):
                self.num_embeddings = num_embeddings # 將num_embeddings賦值給類成員變量self.num_embeddings。
                self.embedding_dim = embedding_dim # 將embedding_dim賦值給類成員變量self.embedding_dim。
                # 創(chuàng)建一個名為weight的Parameter,形狀為(num_embeddings, embedding_dim),
                # 數(shù)據(jù)類型為dtype,并將其賦值給類成員變量self.weight。
                self.weight = nn.Parameter(
                    (num_embeddings, embedding_dim), dtype=dtype, name="weight"
                )

            def forward(self, x: relax.Expr) -> relax.Var:
                # 調(diào)用op.reshape函數(shù)將輸入張量x進行reshape,將其展平為一維張量,并將結(jié)果重新賦值給x。
                # nn.emit是將一個relax.Expr表達式轉(zhuǎn)化為relax.Var變量,并保存該變量。
                x = nn.emit(op.reshape(x, shape=[-1]))
                # 使用op.take操作從self.weight中按照索引x提取對應(yīng)的嵌入向量,并返回結(jié)果。這里的axis=0表示在第一個維度上進行索引操作。
                return nn.emit(op.take(self.weight, x, axis=0))

        以及LayerNorm:

        # 這段代碼定義了一個名為RWKV_LayerNorm的PyTorch模塊,它實現(xiàn)了一個Layer Normalization層。
        class RWKV_LayerNorm(nn.Module):
            # 定義了RWKV_LayerNorm類的構(gòu)造函數(shù),接受四個參數(shù)intermediate_size、dtype、eps和name_prefix。
            def __init__(self, intermediate_size, dtype, eps=1e-5, name_prefix=""):
                super().__init__()
                self.eps = eps
                self.weight = nn.Parameter(
                    (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_weight"
                )
                self.bias = nn.Parameter(
                    (intermediate_size,), dtype=dtype, name=f"{name_prefix}_ln_bias"
                )

            def forward(self, x: relax.Expr) -> relax.Var:
                # 使用op.nn.layer_norm操作對輸入張量x進行Layer Normalization,其中使用Parameter self.weight作為縮放參數(shù)(gamma),
                # 使用可學習參數(shù)self.bias作為偏移參數(shù)(beta),在最后一個維度(axes=-1)上進行標準化操作,
                # 并設(shè)置小數(shù)值修正項為self.eps。將標準化后的結(jié)果重新賦值給x。
                x = nn.emit(
                    op.nn.layer_norm(
                        x,
                        gamma=self.weight,
                        beta=self.bias,
                        axes=-1,
                        epsilon=self.eps,
                    )
                )
                return x

        接著對FFN層做一個詳細的解析:

        # 這段代碼定義了一個名為RWKV_FFN的PyTorch模塊,它實現(xiàn)了Feed-Forward Network(FFN)。
        class RWKV_FFN(nn.Module):
            # 定義了RWKV_FFN類的構(gòu)造函數(shù),接受兩個參數(shù)RWKVConfig和index。
            def __init__(self, config: RWKVConfig, index: int) -> None:
                super().__init__()
                # 將config.hidden_size賦值給類成員變量self.hidden_size,表示隱藏大小。
                self.hidden_size = config.hidden_size
                # 將config.dtype賦值給類成員變量self.dtype,表示數(shù)據(jù)類型。
                self.dtype = config.dtype
                # 將index賦值給類成員變
                self.index = index
                # 建一個名為time_mix_key的可學習參數(shù),形狀為(self.hidden_size,),
                # 數(shù)據(jù)類型為config.dtype,命名為"ffn_{index}_time_mix_k",并將其賦值給類成員變量self.time_mix_key。
                self.time_mix_key = nn.Parameter(
                    (self.hidden_size,), dtype=config.dtype, name=f"ffn_{index}_time_mix_k"
                )
                # 創(chuàng)建一個名為time_mix_receptance的可學習參數(shù),形狀為(self.hidden_size,),數(shù)據(jù)類型為config.dtype,
                # 命名為"ffn_{index}_time_mix_r",并將其賦值給類成員變量self.time_mix_receptance。
                self.time_mix_receptance = nn.Parameter(
                    (self.hidden_size,), dtype=config.dtype, name=f"ffn_{index}_time_mix_r"
                )
                # 創(chuàng)建一個線性層,輸入大小為self.hidden_size,輸出大小為config.intermediate_size,
                # 數(shù)據(jù)類型為config.dtype,沒有偏置項,并將其賦值給類成員變量self.key。
                self.key = Linear(
                    self.hidden_size, config.intermediate_size, dtype=config.dtype, bias=False
                )
                # 創(chuàng)建一個線性層,輸入大小為self.hidden_size,輸出大小為self.hidden_size,數(shù)據(jù)類型為config.dtype,
                # 沒有偏置項,并將其賦值給類成員變量self.receptance。
                self.receptance = Linear(
                    self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False
                )
                self.value = Linear(
                    config.intermediate_size, self.hidden_size, dtype=config.dtype, bias=False
                )

            def forward(self, x: Expr, state: Expr) -> Expr:
                # 計算偏移量,用于在state中獲取對應(yīng)的保存狀態(tài)。
                offset = self.index * 5 + State.FFN_X
                # 獲取x的shape[0]表示上下文長度。
                context_length = x.struct_info.shape[0]
                # 獲取隱藏層大小。
                hidden_size = self.hidden_size

                # 調(diào)用_load_state函數(shù)從state中加載保存的狀態(tài)state[offset],并將結(jié)果賦值給saved_x。
                saved_x = _load_state(state[offset], hidden_size, self.dtype)
                # 如果上下文長度不為1,則執(zhí)行下面的操作。
                if not is_one(context_length):
                    # 調(diào)用nn.emit_te函數(shù),將saved_x和x作為參數(shù)傳遞給
                    # _te_concat_saved_x函數(shù)進行計算,并將結(jié)果重新賦值給saved_x。
                    # 類似于transformer 里面的KV Cache的,但是這里的concat是緯度不變的
                    # 對應(yīng) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 這行代碼
                    saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x)
                # 創(chuàng)建一個全為1的張量,形狀為(hidden_size,),數(shù)據(jù)類型為self.dtype,并將其賦值給ones。
                ones = nn.emit(relax.op.ones((hidden_size,), self.dtype))
                # 計算xk,根據(jù)時間混合參數(shù)self.time_mix_key和保存的狀態(tài)saved_x,使用加權(quán)求和的方式得到。
                # 其中,x和saved_x分別乘以self.time_mix_key和(ones - self.time_mix_key),然后相加。將計算結(jié)果賦值給xk。
                # 對應(yīng) kx = xx * k_mix + sx * (1 - k_mix) 這行代碼
                xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key))
                # 計算xr,根據(jù)時間混合參數(shù)self.time_mix_receptance和保存的狀態(tài)saved_x,使用加權(quán)求和的方式得到。
                # 其中,x和saved_x分別乘以self.time_mix_receptance和(ones - self.time_mix_receptance),然后相加。
                # 將計算結(jié)果賦值給xr。
                # 對應(yīng) rx = xx * r_mix + sx * (1 - r_mix)
                xr = nn.emit(
                    x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance)
                )
                # # 如果上下文長度不為1,則執(zhí)行下面的操作。
                if not is_one(context_length):
                    # 調(diào)用nn.emit_te函數(shù),使用_te_get_last_x函數(shù)從x中獲取最后一個token對應(yīng)的tensor,并將結(jié)果重新賦值給x。
                    # 對應(yīng) xx[-1,:]
                    x = nn.emit_te(_te_get_last_x, x)
                # 斷言x的結(jié)構(gòu)信息(shape)的第一個維度為1。
                assert is_one(x.struct_info.shape[0])
                # 調(diào)用_store_state函數(shù),將x保存到state[offset]中,并將結(jié)果重新賦值給saved_x。
                # 對應(yīng):https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L921
                saved_x = _store_state(state[offset], x)

                # 將xr作為輸入,經(jīng)過sigmoid激活函數(shù)計算得到r。對應(yīng):r = torch.sigmoid(gemm(rx, rw))
                r = nn.emit(op.sigmoid(self.receptance(xr)))
                # 對應(yīng) vx = torch.square(torch.relu(gemm(kx, kw)))
                xv = nn.emit(op.square(op.nn.relu(self.key(xk))))

                return nn.emit(r * self.value(xv)), [saved_x]

        接下來對Attention部分的實現(xiàn)進行解析,注意這部分對應(yīng)的代碼在 https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L728-L747 。貼一下python代碼防止看錯位置產(chǎn)生疑問:

        if os.environ["RWKV_CUDA_ON"] == '1':
                @MyFunction
                def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
                    T, C = x.shape
                    xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
                    sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
                    kx = xx * k_mix + sx * (1 - k_mix)
                    vx = xx * v_mix + sx * (1 - v_mix)
                    rx = xx * r_mix + sx * (1 - r_mix)

                    r = torch.sigmoid(gemm(rx, rw))
                    k = gemm(kx, kw, output_dtype=torch.float32)
                    v = gemm(vx, vw, output_dtype=torch.float32)
                    y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp)
                    
                    out = gemm(r * y.to(x.dtype), ow)
                    return x + out, xx[-1,:], aa, bb, pp

        對應(yīng)mlc-llm RWKV Attention的代碼解析為:

        # 實現(xiàn)RWKV Attention,對應(yīng) https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L479
        class RWKV_Attention(nn.Module):
            # 初始化函數(shù),接受一個config對象和一個整數(shù)index作為參數(shù)。其中config是一個RWKVConfig類型的對象,index表示當前層的索引。
            def __init__(self, config: RWKVConfig, index: int) -> None:
                super().__init__()
                self.index = index
                self.dtype = config.dtype
                self.hidden_size = config.hidden_size
                # 創(chuàng)建一些可學習的參數(shù),如time_decay、time_first、time_mix_key等,這些參數(shù)會在模型的前向傳播中使用。
                self.time_decay = nn.Parameter(
                    (self.hidden_size,), dtype="float32", name=f"att_{index}_time_decay"
                )
                self.time_first = nn.Parameter(
                    (self.hidden_size,), dtype="float32", name=f"att_{index}_time_first"
                )
                self.time_mix_key = nn.Parameter(
                    (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_k"
                )
                self.time_mix_value = nn.Parameter(
                    (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_v"
                )
                self.time_mix_receptance = nn.Parameter(
                    (self.hidden_size,), dtype=config.dtype, name=f"att_{index}_time_mix_r"
                )
                # 前向傳播用到的線性層
                self.key = Linear(
                    self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False
                )
                self.value = Linear(
                    self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False
                )
                self.receptance = Linear(
                    self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False
                )
                self.output = Linear(
                    self.hidden_size, self.hidden_size, dtype=config.dtype, bias=False
                )

            # 前向傳播函數(shù),接受輸入張量x和狀態(tài)張量state作為參數(shù),并返回輸出張量
            def forward(self, x: Expr, state: Expr) -> Expr:
                # Load current state
                # 定義了一些局部變量,如ones、index、hidden_size、context_length等。
                ones = nn.emit(relax.op.ones((self.hidden_size,), self.dtype))
                index = self.index
                hidden_size = self.hidden_size
                context_length = x.struct_info.shape[0]
                bb = relax.BlockBuilder.current()

                # _load_state函數(shù)從state中加載保存的狀態(tài),賦值給saved_a、saved_b、saved_p和saved_x。
                saved_a = _load_state(state[index * 5 + State.ATT_A], hidden_size, "float32")
                saved_b = _load_state(state[index * 5 + State.ATT_B], hidden_size, "float32")
                saved_p = _load_state(state[index * 5 + State.ATT_P], hidden_size, "float32")
                saved_x = _load_state(state[index * 5 + State.ATT_X], hidden_size, self.dtype)
                
                # 調(diào)用nn.emit_te函數(shù),將saved_x和x作為參數(shù)傳遞給
                # _te_concat_saved_x函數(shù)進行計算,并將結(jié)果重新賦值給saved_x。
                # 對應(yīng) sx = torch.cat((sx.unsqueeze(0), xx[:-1,:]))
                if not is_one(context_length):
                    saved_x = nn.emit_te(_te_concat_saved_x, saved_x, x)

                # 對應(yīng) kx = xx * k_mix + sx * (1 - k_mix)
                xk = nn.emit(x * self.time_mix_key + saved_x * (ones - self.time_mix_key))
                # 對應(yīng) vx = xx * v_mix + sx * (1 - v_mix)
                xv = nn.emit(x * self.time_mix_value + saved_x * (ones - self.time_mix_value))
                # 對應(yīng) rx = xx * r_mix + sx * (1 - r_mix)
                xr = nn.emit(
                    x * self.time_mix_receptance + saved_x * (ones - self.time_mix_receptance)
                )

                # 對應(yīng) r = torch.sigmoid(gemm(rx, rw))
                r = nn.emit(op.sigmoid(self.receptance(xr)))
                # 對應(yīng) k = gemm(kx, kw, output_dtype=torch.float32)
                k = nn.emit(op.astype(self.key(xk), "float32"))
                # 對應(yīng) v = gemm(vx, vw, output_dtype=torch.float32)
                v = nn.emit(op.astype(self.value(xv), "float32"))

                # 這部分對應(yīng) y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp)
                # 這里的 create_wkv_func 在上面已經(jīng)解析了
                gv = bb.add_func(create_wkv_func(hidden_size, "float32", self.dtype), "wkv")
                ret = nn.emit(
                    relax.call_tir(
                        gv,
                        [k, v, self.time_decay, self.time_first, saved_a, saved_b, saved_p],
                        [
                            R.Tensor((context_length, hidden_size), self.dtype), # 對應(yīng)wkv
                            R.Tensor((1, hidden_size), "float32"), # 對應(yīng)out_a
                            R.Tensor((1, hidden_size), "float32"), # 對應(yīng)out_b
                            R.Tensor((1, hidden_size), "float32"), # 對應(yīng)out_p
                        ],
                    )
                )
                if not is_one(context_length):
                    # 對應(yīng) xx[-1,:]
                    x = nn.emit_te(_te_get_last_x, x)

                assert is_one(x.struct_info.shape[0])
                saved_x = _store_state(state[self.index * 5 + State.ATT_X], x)
                saved_a = _store_state(state[self.index * 5 + State.ATT_A], ret[1])
                saved_b = _store_state(state[self.index * 5 + State.ATT_B], ret[2])
                saved_p = _store_state(state[self.index * 5 + State.ATT_P], ret[3])

                # 需要注意一下,python代碼里面的 return x + out, xx[-1,:], aa, bb, pp
                # 這里的 x + out被放在attention外面做了,因為這里的x已經(jīng)是被修改之后好的結(jié)果而不是原始的x
                return nn.emit(self.output(r * ret[0])), [
                    saved_x,
                    saved_a,
                    saved_b,
                    saved_p,
                ]

        接著解析一下RWKVLayer的實現(xiàn),請注意下面的最后一行代碼的解釋:

        class RWKVLayer(nn.Module):
            # 初始化函數(shù),接受一個config對象和一個整數(shù)index作為參數(shù)。其中config是一個RWKVConfig類型的對象,index表示層的索引。
            def __init__(self, config: RWKVConfig, index: int) -> None:
                super().__init__()
                # 如果index為0,創(chuàng)建一個RWKV_LayerNorm對象pre_ln,用于對輸入進行Layer Normalization操作。
                if index == 0:
                    self.pre_ln = RWKV_LayerNorm(
                        config.hidden_size,
                        config.dtype,
                        eps=config.layer_norm_epsilon,
                        name_prefix="pre_ln",
                    )
                # 創(chuàng)建兩個RWKV_LayerNorm對象,分別命名為ln1和ln2,
                # 用于對注意力機制和前饋神經(jīng)網(wǎng)絡(luò)的輸出進行Layer Normalization操作。
                self.ln1 = RWKV_LayerNorm(
                    config.hidden_size,
                    config.dtype,
                    eps=config.layer_norm_epsilon,
                    name_prefix=f"att_{index}",
                )
                self.ln2 = RWKV_LayerNorm(
                    config.hidden_size,
                    config.dtype,
                    eps=config.layer_norm_epsilon,
                    name_prefix=f"ffn_{index}",
                )
                # 創(chuàng)建一個RWKV_Attention對象attention,用于實現(xiàn)注意力機制。
                self.attention = RWKV_Attention(config, index)
                # 創(chuàng)建一個RWKV_FFN對象feed_forward,用于實現(xiàn)前饋神經(jīng)網(wǎng)絡(luò)。
                self.feed_forward = RWKV_FFN(config, index)
                self.rescale_every = config.rescale_every
                self.dtype = config.dtype
                self.index = index

            # 前向傳播函數(shù),接受輸入張量x和狀態(tài)張量state作為參數(shù),并返回輸出張量和更新后的狀態(tài)列表。
            def forward(self, x: Expr, state: Expr) -> Tuple[Expr, List[Expr]]:
                # 如果index為0,則將輸入張量x傳入pre_ln進行Layer Normalization操作。
                if self.index == 0:
                    x = self.pre_ln(x)
                # 將經(jīng)過ln1的輸入張量x和狀態(tài)張量state傳入attention進行計算,得到注意力機制的輸出att和更新后的狀態(tài)列表att_state。
                att, att_state = self.attention(self.ln1(x), state)
                # 將輸入張量x和注意力機制的輸出att相加,并將結(jié)果賦值給x。
                x = nn.emit(x + att)
                # 將經(jīng)過ln2的輸入張量x和狀態(tài)張量state傳入feed_forward進行計算,得到前饋神經(jīng)網(wǎng)絡(luò)的輸出ffn和更新后的狀態(tài)列表ffn_state。
                ffn, ffn_state = self.feed_forward(self.ln2(x), state)
                # 將輸入張量x和前饋神經(jīng)網(wǎng)絡(luò)的輸出ffn相加,并將結(jié)果賦值給x。
                x = nn.emit(x + ffn)
                # 如果滿足self.rescale_every > 0且(self.index + 1) % self.rescale_every == 0,則對輸入張量x進行縮放操作。
                if self.rescale_every > 0 and (self.index + 1) % self.rescale_every == 0:
                    x = nn.emit(x / relax.const(2, dtype=self.dtype))
                # 返回輸出張量x和注意力機制和前饋神經(jīng)網(wǎng)絡(luò)的更新后的狀態(tài)列表的拼接。
                return x, att_state + ffn_state

        注意這里的attn_state是[saved_x, saved_a, saved_b, saved_p,] ,然后ffn_state是[saved_x],注意這兩個x是不一樣的,這5個狀態(tài)也和本節(jié)開頭的class State的成員定義一致。

        接下來對RWKV模型定義進行了解析:

        # 該代碼是一個自定義的PyTorch模型類RWKVModel,繼承自nn.Module
        class RWKVModel(nn.Module):
            # 初始化函數(shù),接受一個config對象作為參數(shù)。其中config是一個RWKVConfig類型的對象。
            def __init__(self, config: RWKVConfig) -> None:
                super().__init__()
                # 創(chuàng)建一個RWKV_Embedding對象embeddings,用于實現(xiàn)輸入的嵌入操作。
                self.embeddings = RWKV_Embedding(
                    num_embeddings=config.vocab_size,
                    embedding_dim=config.hidden_size,
                    dtype=config.dtype,
                )
                # 創(chuàng)建一個ModuleList對象blocks,其中包含了config.num_hidden_layers個RWKVLayer對象,
                # 每個對象的索引從0到config.num_hidden_layers-1。
                self.blocks = ModuleList(
                    [RWKVLayer(config, i) for i in range(config.num_hidden_layers)]
                )
                # 創(chuàng)建一個RWKV_LayerNorm對象ln_out,用于對輸出進行Layer Normalization操作。
                self.ln_out = RWKV_LayerNorm(
                    config.hidden_size,
                    config.dtype,
                    eps=config.layer_norm_epsilon,
                    name_prefix="out_ln",
                )
                self.hidden_size = config.hidden_size
                self.dtype = config.dtype

            # 前向傳播函數(shù),接受輸入張量input_ids和狀態(tài)張量state作為參數(shù),并返回輸出張量和更新后的狀態(tài)列表。
            def forward(self, input_ids: Expr, state: Expr) -> Tuple[Expr, List[Expr]]:
                # 將輸入張量input_ids傳入embeddings進行嵌入操作,得到隱藏狀態(tài)張量hidden_states。
                hidden_states = self.embeddings(input_ids)
                # 創(chuàng)建一個空列表states,用于存儲每個RWKVLayer對象的更新后的狀態(tài)列表。
                states = []
                # 遍歷blocks中的每個RWKVLayer對象,將隱藏狀態(tài)張量hidden_states和狀態(tài)張量state傳入
                # 每個RWKVLayer對象的前向傳播函數(shù)進行計算,得到更新后的隱藏狀態(tài)張量和更新后的狀態(tài)列表,
                # 并將更新后的狀態(tài)列表添加到states中。
                for _, layer in enumerate(self.blocks):
                    hidden_states, layer_states = layer(hidden_states, state)
                    states += layer_states
                # 獲取隱藏狀態(tài)張量的上下文長度context_length。
                context_length = hidden_states.struct_info.shape[0]
                # 如果context_length不為1,則調(diào)用_te_get_last_x函數(shù)獲取最后一個token對應(yīng)的張量。
                if not is_one(context_length):
                    hidden_states = nn.emit_te(_te_get_last_x, hidden_states)
                # 將隱藏狀態(tài)張量傳入ln_out進行Layer Normalization操作。
                hidden_states = self.ln_out(hidden_states)
                # 返回輸出隱藏狀態(tài)張量和所有RWKVLayer對象的更新后的狀態(tài)列表。
                return hidden_states, states

        # 該代碼是一個自定義的PyTorch模型類RWKVForCausalLM,繼承自nn.Module。
        class RWKVForCausalLM(nn.Module):
            # 初始化函數(shù),接受一個config對象作為參數(shù)。其中config是一個RWKVConfig類型的對象。
            def __init__(self, config: RWKVConfig):
                # 創(chuàng)建一個RWKVModel對象rwkv,用于實現(xiàn)序列模型的計算。
                self.rwkv = RWKVModel(config)
                # 創(chuàng)建一個Linear對象head,用于將隱藏狀態(tài)映射到詞匯表大小的輸出空間。
                self.head = Linear(
                    config.hidden_size, config.vocab_size, dtype=config.dtype, bias=False
                )
                self.vocab_size = config.vocab_size
                ############ End ############

            # 前向傳播函數(shù),接受輸入張量input_ids和狀態(tài)張量state作為參數(shù),并返回預(yù)測的logits和更新后的kv cache。
            def forward(
                self,
                input_ids: relax.Expr,
                state: relax.Expr,
            )
        :

                # 將輸入張量input_ids和狀態(tài)張量state傳入rwkv對象的前向傳播函數(shù)進行計算,
                # 得到更新后的隱藏狀態(tài)張量hidden_states和key-value緩存key_value_cache。
                hidden_states, key_value_cache = self.rwkv(input_ids, state)
                # 將隱藏狀態(tài)張量hidden_states傳入head進行線性映射操作,得到logits。
                logits = nn.emit(self.head(hidden_states))
                # 對logits進行形狀重塑,將其reshape為形狀為(1, 1, self.vocab_size)的張量。
                logits = nn.emit(op.reshape(logits, (11, self.vocab_size)))
                # 如果logits的數(shù)據(jù)類型不是float32,則將其轉(zhuǎn)換為float32類型。
                if logits.struct_info.dtype != "float32":
                    logits = nn.emit(relax.op.astype(logits, "float32"))

                return logits, key_value_cache

        解下是一個根據(jù)參數(shù)的名字確定量化參數(shù)類型的函數(shù):

        # 該代碼定義了一個函數(shù)get_param_quant_kind,用于根據(jù)參數(shù)名稱和參數(shù)信息確定參數(shù)的量化類型。
        def get_param_quant_kind(
            name: str, param_info: relax.TensorStructInfo
        )
         -> ParamQuantKind:

            # 如果參數(shù)名稱以"embeddings.weight"結(jié)尾,返回ParamQuantKind.embedding_table表示該參數(shù)是嵌入表的權(quán)重。
            if name.endswith("embeddings.weight"):
                return ParamQuantKind.embedding_table
            # 如果參數(shù)名稱為"head.weight",返回ParamQuantKind.final_fc_weight表示該參數(shù)是最后一個全連接層的權(quán)重。
            elif name == "head.weight":
                return ParamQuantKind.final_fc_weight
            # 如果參數(shù)的維度為2且名稱以".weight"結(jié)尾,返回ParamQuantKind.linear_weight表示該參數(shù)是線性層的權(quán)重。
            elif param_info.ndim == 2 and name.endswith(".weight"):
                return ParamQuantKind.linear_weight
            else:
                return ParamQuantKind.others

        上面已經(jīng)完成了RWKV模型的定義,接下來是定義幾個相關(guān)的TIR函數(shù)并定義一個最終的TIR模型獲取函數(shù)。這里對創(chuàng)建prefill和decode的create_func函數(shù)以及最終的TIR模型獲取函數(shù)get_model進行解析:

        由于字數(shù)被公眾號限制了,請在知乎文章查看這部分,https://zhuanlan.zhihu.com/p/658354795

        自此,我們基本就有了搭建RWKV模型的全部流程,說白了就是用TVM的Relax語言手動一對一的把PyTorch實現(xiàn)翻譯過去。

        0x3. Transform舉例

        在mlc-llm有一些圖層的優(yōu)化,在 https://github.com/BBuf/mlc-llm-code-analysis/tree/main/mlc_llm/transform 這個文件里面,我們對其中的一些優(yōu)化Pass做一下解析。

        0x3.1 rewrite attention

        代碼如下:

        # 導入了TVM的relax模塊中的一些函數(shù)和類,以及TVM的script模塊中的relax別名。
        from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard
        from tvm.script import relax as R

        # 定義了一個名為rewrite_attention的函數(shù),接收一個參數(shù)f。
        def rewrite_attention(f):
            # 使用wildcard()創(chuàng)建了三個通配符,分別賦值給Q、K和V。
            Q = wildcard()
            K = wildcard()
            V = wildcard()

            # 使用is_op()函數(shù)創(chuàng)建了三個操作模式,分別對應(yīng)Q、K和V的維度重排操作,并將結(jié)果分別賦值給Q_BNSH、K_BNSH和V_BNSH。
            Q_BNSH = is_op("relax.permute_dims")(Q)
            K_BNSH = is_op("relax.permute_dims")(K)
            V_BNSH = is_op("relax.permute_dims")(V)

            # 使用is_op()函數(shù)創(chuàng)建了一個操作模式,對應(yīng)K_BNSH的維度重排操作,并將結(jié)果賦值給K_BNSH_T。
            K_BNSH_T = is_op("relax.permute_dims")(K_BNSH)

            # 使用is_op()函數(shù)創(chuàng)建了一系列操作模式,對應(yīng)矩陣乘法、除法、最大值、最小值、softmax以及另一個矩陣乘法操作。
            # 這些操作模式(Attention)根據(jù)之前定義的通配符和常數(shù)匹配不同的計算圖節(jié)點。
            matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T)
            divide = is_op("relax.divide")(matmul1, is_const())
            max = is_op("relax.maximum")(divide, is_const())
            min = is_op("relax.minimum")(max, wildcard())
            softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min))
            matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH)

            # 使用is_op()函數(shù)創(chuàng)建了一個操作模式,對應(yīng)matmul2的維度重排操作,并將結(jié)果賦值給pattern。
            pattern = is_op("relax.permute_dims")(matmul2)

            # 定義了一個名為callback的回調(diào)函數(shù),接收兩個參數(shù)_和matchings。
            # 該回調(diào)函數(shù)使用R.nn.attention函數(shù)構(gòu)建一個新的計算圖節(jié)點,并使用matchings字典中的匹配結(jié)果來填充該節(jié)點的參數(shù)。
            def callback(_, matchings):
                return R.nn.attention(
                    matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
                )

            # 使用rewrite_call函數(shù)將pattern、callback和輸入的計算圖f傳遞給它,以便在計算圖中應(yīng)用模式匹配和重寫。
            # 最后,將重寫后的計算圖返回。
            return rewrite_call(pattern, callback, f)

        雖然沒有完全看懂這里的操作比如max和min的含義,但是從后面的callback_可以猜測出這里的Pass就是把打散的Self Attention模塊融合為一個relax.nn.attention操作。在cuda后端,如果支持了cutlass,那么relax.nn.attention操作就對應(yīng)了Flash Attention。

        0x3.2 Transpose MatMul

        代碼實現(xiàn)解析如下:

        # 這段代碼定義了一個名為TransposeMatmulCodeGenerator的類,該類繼承自relax.PyExprMutator。
        # 通過@relax.expr_functor.mutator裝飾器將該類聲明為一個表達式重寫器。
        @relax.expr_functor.mutator
        class TransposeMatmulCodeGenerator(relax.PyExprMutator):
            def __init__(self, mod):
                super().__init__(mod)

            @staticmethod
            def pattern():
                # 定義了靜態(tài)方法pattern(),該方法返回一個描述模式的元組。
                # 通過使用通配符(wildcard())和操作模式(is_op())來匹配計算圖中的特定模式。
                # 在這個例子中,模式匹配了一個矩陣乘法操作中矩陣w的維度重排操作,并將匹配的結(jié)果保存在字典annotations中。
                w = wildcard()
                x = wildcard()
                wT = is_op("relax.permute_dims")(w)
                o = is_op("relax.matmul")(x, wT)
                annotations = {"o": o, "w": w, "x": x, "wT": wT}

                # 定義了內(nèi)部函數(shù)_check(),用于檢查模式匹配的結(jié)果是否滿足特定的條件。
                # 在這個例子中,檢查了維度重排操作的維度數(shù)和軸的順序是否正確。
                def _check(context: relax.transform.PatternCheckContext) -> bool:
                    transpose_call = context.annotated_expr["wT"]
                    ndim = transpose_call.args[0].struct_info.ndim
                    if ndim == -1:
                        return False
                    if ndim == 2 and transpose_call.attrs.axes is None:
                        return True
                    axes = list(range(ndim))
                    axes[-1], axes[-2] = axes[-2], axes[-1]
                    return list(transpose_call.attrs.axes) == axes

                # 將匹配的計算圖節(jié)點、注解和檢查函數(shù)作為元組返回。
                return o, annotations, _check

            # 重寫了父類的visit_call_()方法,用于處理特定類型的計算圖節(jié)點。
            def visit_call_(self, call: relax.Call) -> relax.Expr:
                # 定義了一個變量out_dtype,用于保存輸出的數(shù)據(jù)類型。
                out_dtype = None

                # 定義了一個內(nèi)部函數(shù)te_transposed_matmul(),該函數(shù)實現(xiàn)了矩陣乘法的計算邏輯。
                def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:
                    nonlocal out_dtype
                    # 將輸入張量 a 和 b 的形狀轉(zhuǎn)換為列表形式,分別保存在變量 a_shape 和 b_shape 中。
                    a_shape = list(a.shape)
                    b_shape = list(b.shape)
                    # 定義了兩個布爾變量 a_prepended 和 b_appended,用于標記是否在相應(yīng)的形狀的前面或后面添加了維度。
                    a_prepended = False
                    b_appended = False
                    # 如果輸入張量 a 的形狀為一維,則在其前面添加一個維度,將其形狀修改為 (1, original_shape)。
                    # 同樣地,如果輸入張量 b 的形狀為一維,則在其后面添加一個維度,將其形狀修改為 (original_shape, 1)。
                    if len(a_shape) == 1:
                        a_prepended = True
                        a_shape.insert(01)
                    if len(b_shape) == 1:
                        b_appended = True
                        b_shape.append(1)

                    # 比較 a_shape 和 b_shape 的長度,將結(jié)果保存在布爾變量 is_a_larger 中。
                    # offset 表示兩個形狀長度之差,用于后續(xù)處理。
                    is_a_larger = len(a_shape) > len(b_shape)
                    offset = (
                        len(a_shape) - len(b_shape)
                        if is_a_larger
                        else len(b_shape) - len(a_shape)
                    )

                    # 創(chuàng)建兩個 relax.Var 對象 a_relax 和 bT_relax,用于表示張量 a 和轉(zhuǎn)置后的張量 bT 的結(jié)構(gòu)信息。
                    # a_relax 的形狀和 a 的形狀相同,bT_relax 的形狀是 b 的形狀經(jīng)過維度互換后的結(jié)果。
                    a_relax = relax.Var("a", relax.TensorStructInfo(a.shape))
                    bT_shape = list(b.shape)
                    bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1]
                    bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape))
                    # 使用 relax.op.matmul() 方法對 a_relax 和 bT_relax 進行矩陣乘法運算。
                    # 然后,通過 self.builder_.normalize() 方法對結(jié)果進行歸一化處理,并獲取最終的輸出形狀。
                    output_shape = self.builder_.normalize(
                        relax.op.matmul(a_relax, bT_relax)
                    ).struct_info.shape

                    # 該函數(shù)接受可變數(shù)量的空間索引參數(shù) idx_spatial,
                    def matmul_compute(*idx_spatial):
                        # 并定義了一個名為 k 的規(guī)約軸(reduce axis),其范圍為 0 到 a_shape[-1]。
                        k = te.reduce_axis((0, a_shape[-1]), name="k")

                        # 定義了一個名為 multiply_compute 的內(nèi)部函數(shù),用于計算乘法操作時的索引。
                        def multiply_compute(idx_reduce):
                            a_indices = []
                            b_indices = []

                            # 根據(jù) is_a_larger 的值,將 idx_spatial 中的索引分配給 a_indices 或 b_indices,用于處理形狀長度差異的維度。
                            for i in range(offset):
                                if is_a_larger:
                                    a_indices.append(idx_spatial[i])
                                else:
                                    b_indices.append(idx_spatial[i])
                            for i in range(
                                offset, len(output_shape) - (2 - a_prepended - b_appended)
                            ):
                                # 根據(jù)維度的相等性,將適當?shù)乃饕砑拥?nbsp;a_indices 和 b_indices 中。
                                # 如果維度不相等或無法確定是否相等,則將索引設(shè)為 0 或保持不變。
                                a_dim = a_shape[i if is_a_larger else i - offset]
                                b_dim = b_shape[i if not is_a_larger else i - offset]
                                dim_equal = a_dim == b_dim
                                if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0:
                                    a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1
                                    b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1
                                    a_indices.append(0 if a_dim_is_one else idx_spatial[i])
                                    b_indices.append(0 if b_dim_is_one else idx_spatial[i])
                                else:
                                    a_indices.append(idx_spatial[i])
                                    b_indices.append(idx_spatial[i])

                            # 在乘法操作的索引中添加規(guī)約軸 idx_reduce,并根據(jù) a_prepended 和 b_appended 的值,
                            # 將適當?shù)乃饕砑拥?nbsp;a_indices 和 b_indices 中。
                            if not a_prepended:
                                a_indices.append(idx_spatial[-2 + b_appended])
                            a_indices.append(idx_reduce)
                            if not b_appended:
                                b_indices.append(idx_spatial[-1])
                            b_indices.append(idx_reduce)

                            # 根據(jù) out_dtype 的值,選擇是否進行數(shù)據(jù)類型轉(zhuǎn)換,并返回乘法操作的結(jié)果。
                            dtype = out_dtype
                            if dtype != "":
                                return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)
                            return a(*a_indices) * b(*b_indices)

                        # 在縮減軸 k 上對 multiply_compute 的結(jié)果進行求和操作。
                        return te.sum(multiply_compute(k), axis=k)

                    # 使用 te.compute() 函數(shù)計算最終的輸出,其中使用一個 lambda 函數(shù)將輸入索引傳遞給 matmul_compute 函數(shù),
                    # 并將結(jié)果命名為 "NT_matmul"。整個計算過程將根據(jù) output_shape 進行執(zhí)行。
                    return te.compute(
                        output_shape,
                        lambda *idx: matmul_compute(*idx),  # pylint: disable=unnecessary-lambda
                        name="NT_matmul",
                    )

                # 首先,檢查函數(shù)調(diào)用的操作符 call.op 是否是 relax.GlobalVar 類型。如果是,獲取與該操作符對應(yīng)的函數(shù)對象,
                # 并檢查函數(shù)的屬性中是否包含鍵 "Composite",且其值為 "transpose_matmul_fuse"。
                if isinstance(call.op, relax.GlobalVar):
                    function = self.builder_.get()[call.op]
                    if (
                        "Composite" in function.attrs
                        and function.attrs["Composite"] == "transpose_matmul_fuse"
                    ):
                        # 將函數(shù)的返回類型 function.ret_struct_info.dtype 賦值給變量 out_dtype
                        out_dtype = function.ret_struct_info.dtype
                        # 然后調(diào)用 self.builder_.call_te() 方法,傳遞 te_transposed_matmul 函數(shù)作為參數(shù),
                        # 以及調(diào)用的參數(shù) call.args[1] 和 call.args[0],并指定 primfunc_name_hint 為 "NT_matmul"。
                        return self.builder_.call_te(
                            te_transposed_matmul,
                            call.args[1],
                            call.args[0],
                            primfunc_name_hint="NT_matmul",
                        )

                return super().visit_call_(call)

        # 使用 @tvm.transform.module_pass 裝飾器定義了一個名為 FuseTransposeMatmul 的類,
        # 并指定了優(yōu)化級別 opt_level=0 和 pass 的名稱為 "FuseTransposeMatmul"。
        @tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul")
        class FuseTransposeMatmul:
            # 定義了 transform_module 方法,接受一個名為 mod 的 IRModule 對象和
            # tvm.transform.PassContext 對象作為參數(shù),并返回一個 IRModule 對象。
            def transform_module(
                self, mod: IRModule, ctx: tvm.transform.PassContext
            )
         -> IRModule:

                # 通過調(diào)用 relax.transform.FuseOpsByPattern 并傳遞一個包含單個模式元組的列表,
                # 對模塊 mod 進行融合的轉(zhuǎn)置矩陣乘法操作。
                mod = relax.transform.FuseOpsByPattern(
                    [("transpose_matmul_fuse", *TransposeMatmulCodeGenerator.pattern())]
                )(mod)

                # 創(chuàng)建一個名為 transpose_matmul_codegen 的 TransposeMatmulCodeGenerator 對象,
                # 并對模塊中的每個函數(shù)進行遍歷。如果函數(shù)是 relax.Function 類型,則調(diào)用 transpose_matmul_codegen.visit_expr 
                # 方法對函數(shù)進行轉(zhuǎn)置矩陣乘法代碼生成,并通過 transpose_matmul_codegen.builder_.update_func 方法更新函數(shù)。
                transpose_matmul_codegen = TransposeMatmulCodeGenerator(mod)
                for gv in mod.functions:
                    func = mod[gv]
                    if not isinstance(func, relax.Function):
                        continue
                    func = transpose_matmul_codegen.visit_expr(func)
                    transpose_matmul_codegen.builder_.update_func(gv, func)

                # 返回轉(zhuǎn)置矩陣乘法代碼生成器的 builder 對象中的模塊。
                return transpose_matmul_codegen.builder_.get()

        這個Pass將Transpose算子和一個MatMul算子替換為一個TE表達式的實現(xiàn)來達到融合算子的目的。

        除了上面2種Pass,MLC-LLM還有不少的圖變換Pass,這篇文章就不一一去解析了,大多數(shù)優(yōu)化的目的都是匹配某種Pattern然后用更優(yōu)秀的算子去完成計算。

        量化策略這一塊就不在這篇文章解析了。

        0x4. MLC-LLM優(yōu)缺點個人評價和期待

        0x4.1 優(yōu)點

        • Tune Free。mlc-llm不需要用TVM的AutoTVM/Ansor等等程序去執(zhí)行算子搜索過程,對跨平臺部署是比原始的TVM搭建的模型更清真的。
        • TIR的語法很大程度靠近了PyTorch的API,使得用戶在模型搭建部分不會很困難。
        • 文檔寫得不錯,跟隨教程基本可以完成大多數(shù)平臺的模型部署,并且單Batch下的吞吐和延遲表現(xiàn)都是不錯的。

        0x4.2 缺點

        • 不支持從onnx或者huggingface模型直接轉(zhuǎn)換出TIR,手工實現(xiàn)模型的時候需要相當多的先驗知識,比如在上面的RWKV模型中如果有自定義的cuda kernel,那么這個模型的實現(xiàn)可能只能全權(quán)委托給mlc-ai社區(qū)的核心開發(fā)人員了。
        • KV Cache開的是max_sequence_length這么長,顯然會有顯存的浪費,Serving的時候極限情況下可以服務(wù)的用戶數(shù)量應(yīng)該比VLLM/TGI等要?。?
        • CUDA后端Decoding的Attention我看起來好像還是會用Flash Attention?也許是我看錯了,這條暫時存疑。
        • 在RWKV模型實現(xiàn)里,看到Batch維度寫死為1了,應(yīng)該不支持動態(tài)Batch?這樣對于啟真實服務(wù)來說會有一些限制。

        0x4.3 期待

        • 如果短期內(nèi)能讓一個對TVM只有輕度依賴的社區(qū)開發(fā)者新增一個新的模型。
        • 如果模型存在自定義CUDA Kernel,需要一個詳細的教程來指引。
        • 模型逐層打印來debug精度缺一個教程。
        • Paged Attention類似策略的引入。
        • 動態(tài)Batch的支持。

        暫時就想到這些,歡迎斧正。


        為了感謝讀者的長期支持,今天我們將送出三本由 北京理工大學出版社 提供的:《深度學習與計算機視覺:核心算法與應(yīng)用》 。點擊下方抽獎助手參與抽獎。沒抽到的小伙伴可以使用下方鏈接購買。

        《深度學習與計算機視覺:核心算法與應(yīng)用》抽獎鏈接


        瀏覽 776
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
          
          

            1. 男女做爱免费在线观看 | 亚洲专区视频 | 国产精品禁久久 | 日欧成人AV | 丁香五月天激情网 |