1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        大佬是怎么優(yōu)雅實現(xiàn)矩陣乘法的?

        共 2047字,需瀏覽 5分鐘

         ·

        2021-06-25 08:22

        作者丨立交橋跳水冠軍
        來源丨h(huán)ttps://zhuanlan.zhihu.com/p/383115932
        編輯丨GiantPandaCV


        今天一翻朋友圈,發(fā)現(xiàn)好多人轉(zhuǎn)發(fā)一個業(yè)內(nèi)大佬寫的開源項目。內(nèi)容很簡單,就是在CPU上實現(xiàn)單精度矩陣乘法。看了一下,結(jié)果非常好:CPU的利用率很高。更可貴的是核心代碼只有很短不到200行。

        之前總覺得自己很了解高性能計算,無外乎就是“局部性+向量”隨便搞一搞。但是嘴上說說和實際實現(xiàn)自然有很大差別??赐炅舜罄械拇a覺得受益匪淺,在這里總結(jié)了一下,當作自己的讀書筆記了。

        最前面自然是要放項目鏈接,強烈推薦大家讀一讀源代碼:https://github.com/pigirons/sgemm_hsw

        =========================正文===============================

        問題描述:給定兩個矩陣,其shape分別為(m,k)和(k, 24),求矩陣相乘的結(jié)果。

        為了方便理解,這里直接把m和k弄了一個數(shù)值帶了進去。所以我們的問題如下:輸入是棕色矩陣A和藍色矩陣B,求紅色矩陣C

        我們知道一般矩陣乘法就是一堆循環(huán)的嵌套,這個也不例外。在代碼里,最外層結(jié)果是輸出矩陣的行遍歷。又因為會有向量化的操作,所以最終結(jié)果是:最外層的循環(huán)每次算4行輸出(PS:這里面的4是固定的,并不是我為了方便隨便設(shè)的)。

        就是下面的情況:

        現(xiàn)在我們拆開來看每輪循環(huán):我們每輪會算4行,24列的輸出。在這里,我們把輸出用12個向量寄存器表示。

        現(xiàn)在可以隱約看出來為什么大佬要固定24這個數(shù)字了:因為ymm寄存器只有16個,我們又希望行數(shù)可以比較整,那么我們每次處理4行比較合適,處理4行的話,每行可以有16/4=4個寄存器。但是我們要做向量運算的話,那我們一定又要有向量寄存器當作運算符,所以我們不能把這16個寄存器都用來存output。所以權(quán)衡一下,那我們每行用3個寄存器好了,這樣總共12個寄存器存結(jié)果,剩下4個用來搞搞計算。因為ymm是256bit的,可以存8個float類型,所以我們每列就應(yīng)該是24

        確定了計算的目標,下面我們繼續(xù)更進一步,來看我們在每個內(nèi)存循環(huán)都要做什么。還記得我們之前剩了4個ymm寄存器么?現(xiàn)在我們把它們都利用上:先來思考下我們能不能直接在A矩陣用ymm?如果用的話,那么我們會把A矩陣一行的連續(xù)數(shù)據(jù)存到一起。這些數(shù)據(jù)會和誰運算呢?是B的一列數(shù)據(jù),也就是圖中黑色的部分。一般來說我們假設(shè)矩陣都是列連續(xù)的。那么訪問黑色的部分,locality就會很差:我們要把這些數(shù)字一個一個讀出來,塞到一個ymm里面和A的ymm進行運算。

        用排除法,我們別無選擇,只能把ymm用到B上面。B也是24列,我們用3個ymm就存下了。還剩一個,我們先把A的第一行第一列的數(shù)字讀出來,把它復(fù)制8份拓展成一個ymm,然后和這三個B的ymm作element-wise的乘法,把結(jié)果累加到y(tǒng)mm0~ymm2里。

        現(xiàn)在發(fā)現(xiàn)這個算法的精妙了么?對的!他正好把16個ymm都用上了,一個不多一個不少

        之后我們該干嘛?其實有很多選擇,比如我們把ymm12~ymm14往下移動一行,和第一行第二列的數(shù)字做乘法,如下圖:

        (?? 這個是低效的做法)
        正確性上來說,上面的做法沒問題。但我們來看看下圖里大佬是怎么做的:

        相比于之前我們說的循環(huán)到A的第一行第二列,大佬循環(huán)到了第二行第一列:在這種情況下我們只需要重新構(gòu)造ymm15,原來的ymm12~ymm14完全都不需要變,不需要讀新的數(shù)值,只需要改變輸出位置,從原來寫到y(tǒng)mm0~ymm2變成了ymm3~ymm5。但因為是寫寄存器而非內(nèi)存,所以都一樣。

        說到這兒,大概也把循環(huán)捋清楚了:最內(nèi)層是按照A的列來迭代:(1)把A的第一行第一列讀出來構(gòu)造ymm15做計算,(2)把A的第二行第一列讀出來構(gòu)造ymm15做計算。。。。一直讀到A的第四行第一列(為什么是第四行?因為我們輸出是四行的寄存器),然后開始讀A的第一行第二列構(gòu)造ymm,然后讀A的第二行第二列構(gòu)造ymm。。。

        總結(jié):

        (1)寫并行計算,感覺就像在下國際象棋:你有很多種走法,這些走法都合法,但是最優(yōu)的只有一種。

        (2)實際上寫高性能的程序就是在湊數(shù):在這個代碼里,我們根據(jù)體系結(jié)構(gòu)里ymm的寬度和ymm的寄存器個數(shù),推導(dǎo)出我們輸出矩陣每行得有24列。然后又繼續(xù)湊湊湊,得到了4步的步長的循環(huán)。雖然都是湊數(shù),但是大佬的代碼湊的很好:每一個ymm都被利用到了,這就是人家的水平


        - The End -

        GiantPandaCV

        長按二維碼關(guān)注我們

        本公眾號專注:

        1. 技術(shù)分享;

        2. 學(xué)術(shù)交流;

        3. 資料共享。

        歡迎關(guān)注我們,一起成長!



        瀏覽 66
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            国产成人精品免费视频大全办公室 | chinese粉嫩露出vide | 夜夜嗨AⅤ一区二区三区 | 六月婷色 | 三上悠亚在线一区 | 任我爽在线视频 | 男人舔女人下面高潮 | 免费的人成无码大片在线观看 | 男人操女人的下面 | 韩国成人精品三级 |