在GPU上加速RWKV6模型的Linear Attention計算
共 48473字,需瀏覽 97分鐘
·
2024-05-05 23:08
精簡版:經(jīng)過一些profile發(fā)現(xiàn)flash-linear-attention中的rwkv6 linear attention算子的表現(xiàn)比RWKV-CUDA中的實現(xiàn)性能還要更好,然后也看到了繼續(xù)優(yōu)化triton版本kernel的線索。接著還分析了一下rwkv6 cuda kernel的幾次開發(fā)迭代以此說明對于不懂cuda以及平時無法從擅長cuda的大佬身上取經(jīng)的人比如我就完全放棄cuda了,可以深入學一下和使用triton,這已經(jīng)完全足夠了(除了會寫之外還可以了解內(nèi)部的MLIR相關的編譯器知識,可以對GPU體系架構理解得更加深刻)。
0x0. 前言
本文主要講一些看到的RWKV 6模型的Linear Attention模塊推理加速方法,在這篇博客中暫不涉及對kernel的深入解析。首先,flash-linear-attention(https://github.com/sustcsonglin/flash-linear-attention )這個倉庫旨在對各種線性Attention架構進行工程加速,例如RetNet,GLA,Manba,RWKV6(2024年4月引入)。它使用Triton來編寫代碼,并針對不同的線性Transformer架構使用不同的優(yōu)化方式。例如對于RWKV 6就采用在時間維度進行kernel fuse的方式來加速。其次,RWKV-CUDA是RWKV系列模型迭代中針對Linear Attention模塊的改進開發(fā)的自定義高性能cuda kernel(https://github.com/BlinkDL/RWKV-CUDA)。flash-rwkv(https://github.com/BBuf/flash-rwkv)倉庫在RWKV-CUDA的最優(yōu)性能算子的基礎上進行了封裝,提供了rwkv5_cuda_linear_attention和rwkv6_cuda_linear_attention兩個接口方便在HuggingFace模型實現(xiàn)中直接加速推理的prefill階段速度。
本篇文章主要會對比一下RWKV6 Linear Attention模塊的naive實現(xiàn)(pure pytorch),RWKV-CUDA的RWKV6 Linear Attention cuda kernel實現(xiàn)(用flash-rwkv提供的接口進行測試),flash-linear-attention里的RWKV6 Linear Attention實現(xiàn)。來說明Triton已經(jīng)成為目前LLM時代開發(fā)的一個趨勢,小伙伴們確實可以學起來。目前我對Triton的了解也非常少而且很膚淺,后續(xù)也會持續(xù)學習和實踐。
下面列舉本文相關的資料,如果你想對RWKV 6這個架構有一些了解可以閱讀后面三個鏈接,當然不閱讀也不影響閱讀本文:
-
https://github.com/sustcsonglin/flash-linear-attention -
https://mp.weixin.qq.com/s/Vol_LeHVHDAwE1pWTHOl2Q -
梳理RWKV 4,5(Eagle),6(Finch)架構的區(qū)別以及個人理解和建議 -
RWKV 模型保姆級微調(diào)教程
另外,本文使用了PyTorch Profiler TensorBoard 插件來做程序的性能分析,感興趣的小伙伴可以在系統(tǒng)調(diào)優(yōu)助手,PyTorch Profiler TensorBoard 插件教程 獲取到詳細的教程。
0x1. 瓶頸是什么
RWKV6 推理 Prefill 階段的性能瓶頸就在于RWKV6模型代碼中的rwkv6_linear_attention_cpu函數(shù):https://huggingface.co/RWKV/rwkv-6-world-1b6/blob/main/modeling_rwkv6.py#L54-L104
def rwkv6_linear_attention(
training,
receptance,
key,
value,
time_decay,
time_first,
state,
):
no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
# Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
# in this case).
one_token = key.size(1) == 1
if no_cuda or one_token:
return rwkv6_linear_attention_cpu(
receptance, key, value, time_decay, time_first, state
)
else:
...
這里的判斷是如果是decode階段(對比prefill階段)或者非GPU模式執(zhí)行代碼,就使用rwkv6_linear_attention_cpu這個算子,否則就使用優(yōu)化后的實現(xiàn)比如使用這里的cuda kernel(https://github.com/BlinkDL/RWKV-CUDA/tree/main/wkv6)編譯出的CUDA Kernel。flash-linear-attention庫的目的是使用Triton來加速rwkv6_linear_attention_cpu這個naive的實現(xiàn)。這個naive實現(xiàn)的代碼如下:
def hf_rwkv6_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
# For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
# within a torch.no_grad.
batch, seq_length, _ = receptance.shape
num_heads, head_size = time_first.shape
key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1) # b, t, h, n -> b, h, t, n -> b, h, n, t
value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # b, t, h, n -> b, h, t, n
receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # b, t, h, n -> b, h, t, n
time_decay = torch.exp(-torch.exp(time_decay.float())).view(batch, seq_length, num_heads, head_size).permute(0, 2, 3, 1) # b, t, h, n -> b, h, n, t
time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1) # h, n -> h * n, 1, 1 -> h, n, 1
out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
for current_index in range(seq_length):
current_receptance = receptance[:, :, current_index:current_index+1, :]
current_key = key[:, :, :, current_index:current_index+1]
current_value = value[:, :, current_index:current_index+1, :]
current_time_decay = time_decay[:, :, :, current_index:current_index+1]
attention_output = current_key @ current_value
out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
with torch.no_grad():
# attention_output.shape: [b, h, n, 1] x [b, h, 1, n] -> [b, h, n, n]
# current_time_decay * state: [b, h, n, 1] * [b, h, n, n] ->[b, h, n, n]
# state.shape: [b, h, n, n]
state = attention_output + current_time_decay * state
return out, state
這樣看代碼可能會有點懵,可以看下一節(jié)的完整demo測試代碼。
0x2. Profile代碼編寫
上一節(jié)明確了,我們需要加速RWKV模型中rwkv6_linear_attention_cpu的計算,https://github.com/sustcsonglin/flash-linear-attention 這個庫在2024年4月份支持了RWKV6模型,它加速RWKV 6 Linear Attention計算的核心api有兩個,fused_recurrent_rwkv6和chunk_rwkv6?,F(xiàn)在直接寫出profile的代碼(https://github.com/BBuf/flash-rwkv/blob/main/profile/profile_rwkv6_linear_attention.py)來對naive的實現(xiàn),RWKV官方提供的cuda kernel以及fused_recurrent_rwkv6和chunk_rwkv6進行性能分析。
import sys
import torch
from fla.ops.rwkv6.chunk import chunk_rwkv6
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
from flash_rwkv import rwkv6_cuda_linear_attention
def hf_rwkv6_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
# For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
# within a torch.no_grad.
batch, seq_length, _ = receptance.shape
num_heads, head_size = time_first.shape
key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
time_decay = torch.exp(-torch.exp(time_decay.float())).view(batch, seq_length, num_heads, head_size).permute(0, 2, 3, 1)
time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
for current_index in range(seq_length):
current_receptance = receptance[:, :, current_index:current_index+1, :]
current_key = key[:, :, :, current_index:current_index+1]
current_value = value[:, :, current_index:current_index+1, :]
current_time_decay = time_decay[:, :, :, current_index:current_index+1]
attention_output = current_key @ current_value
out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
with torch.no_grad():
state = attention_output + current_time_decay * state
return out, state
if __name__ == "__main__":
mode = sys.argv[1]
B = 1
H = 32
L = 54
D = 64
HIDDEN_SIZE = H * D
dtype = torch.float32
if mode == 'hf':
profile_path = '/bbuf/rwkv_profile_result/hf/'
elif mode == 'recurrent':
profile_path = '/bbuf/rwkv_profile_result/recurrent/'
elif mode == 'chunk':
profile_path = '/bbuf/rwkv_profile_result/chunk/'
elif mode == 'cuda':
profile_path = '/bbuf/rwkv_profile_result/cuda'
else:
raise NotImplementedError
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(
wait=1,
warmup=1,
active=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_path, worker_name='worker0'),
record_shapes=True,
profile_memory=True, # This will take 1 to 2 minutes. Setting it to False could greatly speedup.
with_stack=True
) as p:
for i in range(10):
q = (torch.randn(B, L, HIDDEN_SIZE).cuda().to(torch.float16)).requires_grad_(True)
k = (torch.randn(B, L, HIDDEN_SIZE).cuda().to(torch.float16)).requires_grad_(True)
v = torch.randn(B, L, HIDDEN_SIZE).cuda().to(torch.float16).requires_grad_(True)
w = torch.nn.functional.logsigmoid(torch.randn(B, L, HIDDEN_SIZE)).cuda().to(torch.float32).requires_grad_(True)
u = (torch.randn(H, D).cuda().to(torch.float16)).requires_grad_(True)
state = (torch.randn(B, H, D, D).cuda().to(torch.float32)).requires_grad_(True)
if mode == 'hf':
o1, state1 = hf_rwkv6_linear_attention_cpu(q, k, v, w, u, state)
elif mode =='cuda':
o2, state2 = rwkv6_cuda_linear_attention(q, k, v, w, u.flatten(), state)
else:
batch, seq_length, _ = q.shape
num_heads, head_size = u.shape
k = k.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K -> B, H, T, K
v = v.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K - > B, H, T, V
q = q.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, H, T, K
w = -torch.exp(w.float()).view(batch, seq_length, num_heads, head_size).permute(0, 2, 1, 3) # B, T, H, K -> B, H, T, K
u = u.float().reshape(num_heads, head_size) # H, K
if mode == 'recurrent':
o3, state3 = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=state, scale=1.0, output_final_state=True)
elif mode == 'chunk':
o4, state4 = chunk_rwkv6(q, k, v, w, u, initial_state=state, scale=1.0, output_final_state=True)
p.step()
這段代碼就是要分別profile hf_rwkv6_linear_attention_cpu,rwkv6_cuda_linear_attention,fused_recurrent_rwkv6,chunk_rwkv6這三個api看一下它們的性能表現(xiàn)以及GPU kernel的詳細使用情況。但這段代碼中有一些需要說明的地方:
-
hf_rwkv6_linear_attention_cpu這個api接收的輸入Tensor形狀和fla包提供的兩個加速api的輸入Tensor形狀不一樣,所以在對hf_rwkv6_linear_attention_cpu設定輸入之后需要經(jīng)過一些維度重排操作才能給fla包的兩個api使用。 -
-
對于 time_decay來說,hf_rwkv6_linear_attention_cpu在計算時做了兩次exp,而fused_recurrent_rwkv6和chunk_rwkv6的api內(nèi)部會做一次exp,所以輸入給fused_recurrent_rwkv6和chunk_rwkv6的time_decay只需要做內(nèi)層的-exp操作就足夠了。 -
對于輸出來說, fused_recurrent_rwkv6和chunk_rwkv6的結果需要轉(zhuǎn)置一下才能得到和hf_rwkv6_linear_attention_cpu一樣的計算結果,state不需要做額外操作,直接就可以對應上。 -
注意api的調(diào)用方式,例如 chunk_rwkv6(q, k, v, w, u, initial_state=state, scale=1.0, output_final_state=True)里面的kwargs是缺一不可的。
接下來就可以執(zhí)行這個profile腳本分別得到這三個api的profile結果了。我在一張NVIDIA A800-SXM4-80GB上進行了profile,結果上傳到了 https://github.com/BBuf/flash-rwkv/tree/main/profile/rwkv_profile_result ,你可以通過 tensorboard --logdir=./rwkv_profile_result/recurrent/ --bind_all 這樣的命令來可視化結果,并在本地的瀏覽器中打開 http://localhost:6006/#pytorch_profiler 網(wǎng)址來查看詳細的結果。
0x3. Profile結果分析
0x3.1 hf_rwkv6_linear_attention_cpu 函數(shù)profile結果
使用hf_rwkv6_linear_attention_cpu函數(shù)進行計算時Kernel部分花了1105us,算子總的時間花了21.5ms,然后它的kernel分布為:
我們可以發(fā)現(xiàn)在kernel里面只有gemv相關的矩陣乘調(diào)用,并且elementwise算子占比非常大已經(jīng)接近40%。
0x3.2 rwkv6_cuda_linear_attention API profile結果
kernel的執(zhí)行時間為73us,算子執(zhí)行的總時間只花了4.5ms,相比于naive的實現(xiàn)(21.5)速度有大幅提升。觀察GPU kernel執(zhí)行情況:
現(xiàn)在
rwkv6_cuda_linear_attention中的核心kernel: kernel_forward執(zhí)行時間為101us。并且現(xiàn)在這個版本只有上面截圖的2個kernel有耗時,剩下的2個elementwise的kernel耗時只有2us。由此可見,使用cuda來編寫和優(yōu)化上面的
rwkv6_cuda_linear_attention api可以獲得大幅度的性能提升。
0x3.3 fused_recurrent_rwkv6 API profile結果
現(xiàn)在Kernel執(zhí)行總時間只有125us,算子總的時間花了5.26ms,相比于naive的實現(xiàn)(21.5)速度有大幅提升,同時kernel的占比也明顯更小,GPU kernel分布情況:
在GPU kernel的具體執(zhí)行分布中,
fused_recurrent_rwkv6_fwd_kernel已經(jīng)是比例的最大的kernel了,而這個kernel的整體耗時非常低只花了64us,而在naive的實現(xiàn)中則存在數(shù)個耗時超過100us的elementwise kernel。目前的整體耗時和優(yōu)化后的cuda kernel實現(xiàn)也是比較接近的。
0x3.4 chunk_rwkv6 API profile結果
chunk_rwkv6的情況和fused_recurrent_rwkv6類似,也是達到了不錯的性能。
0x3.5 Profile結果總結
| 方法 | RWKV 6 Linear Attention端到端耗時(us) | Kernel最大耗時(us) |
|---|---|---|
| hf_rwkv6_linear_attention_cpu | 21500 | 432us |
| rwkv6_cuda_linear_attention | 4500 | 101us |
| fused_recurrent_rwkv6 | 5260 | 64us |
| chunk_rwkv6 | 5602 | 49us |
注:hf_rwkv6_linear_attention_cpu中有很多個耗時比較長的element-wise kernel,性能是最差的,這里只記錄了耗時最長的那個element-wise kernel,已經(jīng)足夠說明問題。后續(xù)三種方案都通過kernel fuse讓hf_rwkv6_linear_attention_cpu實現(xiàn)中的seq_length維度的遍歷和眾多gemv/elemetwise相關kernel最終fuse成1個或者2個kernel。chunk_rwkv6 api的計算分為2個kernel,耗時分別為27和22us,統(tǒng)計kernel最大耗時的時候進行了求和。
結論:手工優(yōu)化的rwkv6_cuda_linear_attention在端到端的耗時方面目前是最快的,從上面的profile代碼也可以看出來主要原因是因為它不需要對各個輸入進行一系列的維度轉(zhuǎn)換,而naive的實現(xiàn)和Triton的實現(xiàn)則必須做一堆維度轉(zhuǎn)換來匹配api提供的計算功能。從Kernel最大耗時的角度看,triton實現(xiàn)的fused_recurrent_rwkv6和chunk_rwkv6 kernel本身的計算是比RWKV-CUDA的手工kernel更快的(雖然還不太清楚Triton實現(xiàn)的版本在編譯中發(fā)生了什么,但真的找到了放棄cuda的理由,畢竟不是專業(yè)做這個東西的,而Triton大家都可以寫),后續(xù)應該會考慮在Triton kernel的基礎上繼續(xù)做優(yōu)化以及訓練性能驗證。
0x4. flash-rwkv庫中的rwkv5_cuda_linear_attention開發(fā)歷程
這里講一下flash-rwkv庫中的rwkv5_cuda_linear_attention這個api背后開發(fā)的迭代歷程。時間回到2023年8月,ChatGPT的火爆讓我也想?yún)⑴c到開源的大模型開發(fā)過程中,然后Peng Bo說可以參與到實現(xiàn)RWKV5 CUDA算子的事情。為了鍛煉下CUDA就開始參與實現(xiàn)和優(yōu)化RWKV5 CUDA,在這個過程中也有幸見識到了RWKV開源社區(qū)中 https://github.com/Blealtan 這位大佬的優(yōu)化水平,同時也了解了Parallel Scan算法和實現(xiàn)。后續(xù)RWKV6的rwkv6_cuda_linear_attention仍然沿用了rwkv5的cuda kernel,只做了微量的坐標映射修改。
HuggingFace中RWKV5模型的Linear Attention Naive實現(xiàn)在 https://huggingface.co/RWKV/rwkv-5-world-1b5/blob/main/modeling_rwkv5.py#L62-L84 ,貼一下這段代碼。
def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
input_dtype = receptance.dtype
# For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
# within a torch.no_grad.
batch, seq_length, hidden_size = receptance.shape
num_heads, head_size = time_first.shape
key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(num_heads, -1, 1)
time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)
for current_index in range(seq_length):
current_receptance = receptance[:, :, current_index:current_index+1, :]
current_key = key[:, :, :, current_index:current_index+1]
current_value = value[:, :, current_index:current_index+1, :]
attention_output = current_key @ current_value
out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
with torch.no_grad():
state = attention_output + time_decay * state
return out, state
要把這段代碼變成cuda kernel,首先需要在形式上做一些還原,使得它更靠近原始的計算公式。還原之后的原始計算公式如下(https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5/run.py#L67-L87):
def RUN_FORMULA_1A(B, T, C, H, r, k, v, w, u):
N = C // H
r = r.view(B, T, H, N)
k = k.view(B, T, H, N)
v = v.view(B, T, H, N)
w = w.view(H, N)
u = u.view(H, N)
out = torch.zeros((B, T, H, N), device=DEVICE)
for b in range(B):
for h in range(H):
state = torch.zeros((N,N), device=DEVICE).contiguous()
for t in range(T):
for i in range(N):
for j in range(N):
x = k[b,t,h,j] * v[b,t,h,i]
s = state[i,j]
out[b,t,h,i] += r[b,t,h,j] * (u[h,j] * x + s)
state[i,j] = s * w[h,j] + x
return out.view(B, T, C)
這里有5個循環(huán),其中N一般比較小,對于RWKV5和RWKV6來說,N一般固定為64。還有就是這個還原的公式?jīng)]有返回state,而是在B,H的內(nèi)循環(huán)中申請了一個局部的state,為了保持和上面的公式一致,需要把state的形狀改成[B, H, N, N],就像在profile代碼編寫那一節(jié)看到的這樣。這里的系列kernel暫時不考慮全局state,因為訓練的時候類似于推理的Prefill,不需要有這個state。有了這個代碼之后,只需要想好開多少個Block以及每個Block開多少個Thread就可以寫出一個Baseline了,然后逐步優(yōu)化。
0x4.1 BaseLine
這個是BaseLine kernel的鏈接:https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5/cuda/wkv5_cuda_ref.cu
首先看一下Block數(shù)和每個Block的線程數(shù):
void cuda_forward(int B, int T, int C, int H, float *r, float *k, float *v, float *w, float *u, float *y)
{
dim3 threadsPerBlock( min(B*C, 32) );
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, H, r, k, v, w, u, y);
}
每個Block使用min(B*C, 32)個線程,然后Block數(shù)就是B*C//threadsPerBlock.x,上面的公式有5個循環(huán),這里的C=H*N,也就是說這里會把第1個,第2個,第4個循環(huán)分配給CUDA kernel,那么可以預見kernel中每個線程的計算過程肯定還有一個T和N的循環(huán)。瀏覽下這里的cuda kernel:
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _w, const F *__restrict__ const _u,
F *__restrict__ const _y)
{
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _h = (idx / N) % H;
const int _i = idx % N;
const int _o0 = _b*T*C + _h*N;
const int _o1 = _h*N;
const F *__restrict__ const k = _k + _o0;
const F *__restrict__ const v = _v + _o0 + _i;
const F *__restrict__ const r = _r + _o0;
F *__restrict__ const y = _y + _o0 + _i;
float state[N] = {0};
for (int __t = 0; __t < T; __t++)
{
const int _t = __t*C;
const F vv = v[_t];
for (int _j = 0; _j < N; _j++)
{
const int j = _t + _j;
const int m = _o1 + _j;
const float x = k[j] * vv;
const float s = state[_j];
atomicAdd(y + _t, r[j] * (_u[m] * x + s));
state[_j] = s * _w[m] + x;
}
}
}
觀察這個baseline的kernel,首先通過線程id確定當前線程所在的第一循環(huán)b,第二循環(huán)h,第4循環(huán)i的位置,然后對T以及最后的N循環(huán)進行遍歷,按照公式計算結果并使用atomicAdd累計答案。
0x4.1 不必要的atomicAdd
對于每個線程來說它都有唯一的線程id,上面代碼中F *__restrict__ const y = _y + _o0 + _i;這里的_o0+i一定是唯一的,所以這個atomicAdd可以去掉,用一個普通的變量來累加答案即可。https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5/cuda/wkv5_cuda_v1a.cu
0x4.2 float4向量化
每個線程會在2個循環(huán)上頻繁訪問數(shù)據(jù)并計算,這里使用float4向量化讀數(shù)據(jù)將有直接的收益。https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5/cuda/wkv5_cuda_v1b.cu
0x4.3 線程塊的調(diào)整
在上面的版本中,每個Block的線程數(shù)是min(B*C, 32),而對于RWKV5和RWKV6系列的模型來說,C=H*D=H*64一定是超過32的,所以每個Block的線程數(shù)一定是32,也就是一個warp。從如何設置CUDA Kernel中的grid_size和block_size? 可知線程數(shù)太少會導致SM的Occupancy無法打滿,導致性能變差,最好是每個Block直接開128個線程。但RWKV 5里面的調(diào)整是將每個Block的線程數(shù)調(diào)整到64,具體見:https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5_bf16/cuda/wkv5_cuda_v1b.cu
0x4.4 Shared Memory
觀察到在第三和第五兩個循環(huán)下,會頻繁訪問r, k, u, w,因此可以把這幾個數(shù)據(jù)存入shared memory再讀取。https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5_bf16/cuda/wkv5_cuda_v1b.cu
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
F *__restrict__ const _y)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;
_w += h*_N_;
_u += h*_N_;
__shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
float state[_N_] = {0};
__syncthreads();
u[i] = float(_u[i]);
w[i] = float(_w[i]);
__syncthreads();
for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
{
__syncthreads();
r[i] = float(_r[t]);
k[i] = float(_k[t]);
__syncthreads();
const float v = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j+=4)
{
const float4& r_ = (float4&)(r[j]);
const float4& k_ = (float4&)(k[j]);
const float4& w_ = (float4&)(w[j]);
const float4& u_ = (float4&)(u[j]);
float4& s = (float4&)(state[j]);
float4 x;
x.x = k_.x * v;
x.y = k_.y * v;
x.z = k_.z * v;
x.w = k_.w * v;
y += r_.x * (u_.x * x.x + s.x);
y += r_.y * (u_.y * x.y + s.y);
y += r_.z * (u_.z * x.z + s.z);
y += r_.w * (u_.w * x.w + s.w);
s.x = s.x * w_.x + x.x;
s.y = s.y * w_.y + x.y;
s.z = s.z * w_.z + x.z;
s.w = s.w * w_.w + x.w;
}
_y[t] = F(y);
}
}
這里如果想把state也存入shared memory,那么state就需要做成一個全局的state這樣才可以只開N的大小否則就需要開N*N的大小導致SM上shared memory大小不夠。
每個Block開啟了64個線程,也就是2個warp,對于warp里面的每個線程來說,它在訪問r, k, u, w的時候必定是獨立且連續(xù)的,因為這些訪問都在N這個循環(huán)中,不會發(fā)生Bank Conflict。
這就是rwkv5_cuda_linear_attention對應的cuda kernel目前的狀態(tài)。但,怎么就被Triton秒了?
0x5. Triton實現(xiàn)粗略瀏覽
Triton的實現(xiàn)也是根據(jù)naive的實現(xiàn)來的,先看一下naive的實現(xiàn)以及相關的輸入。https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/rwkv6/recurrent_naive.py#L8-L36
def naive_recurrent_rwkv6(
q,
k,
v,
w,
u,
initial_state=None,
output_final_state=False
):
orig_dtype = q.dtype
q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u))
batch_size, n_heads, seq_len, d_head_k = q.shape
_, _, _, d_head_v = v.shape
h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device)
o = torch.zeros_like(v)
if initial_state is not None:
h += initial_state
for i in range(seq_len):
q_i = q[:, :, i, :]
k_i = k[:, :, i]
v_i = v[:, :, i, :]
w_i = w[:, :, i].exp()
kv_i = k_i[..., None] * v_i[..., None, :]
o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None]
o[:, :, i] = o_i.sum(-2)
h = h * w_i[..., None] + kv_i
return o.to(orig_dtype)
q, k, v, w, u等定義如下:
B = 4
H = 4
L = 1024
D = 100
dtype = torch.float32
q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True)
v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True)
w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(torch.float32).requires_grad_(True)
u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(True)
do = torch.rand_like(v).cuda()
o = naive_recurrent_rwkv6(q, k, v, w, u)
這里q,k,v的head dim維度我重新設置為了D。
然后在實現(xiàn)fused_recurrent_rwkv6的時候各個輸入tensor的shape也沿用了這里的設置。接口定義在 https://github.com/sustcsonglin/flash-linear-attention/blob/main/fla/ops/rwkv6/recurrent_fuse.py#L403 。
# if scale is None, use d_head_qk ** -0.5 by default. Otherwise specify the scale yourself. e.g. scale = 1.0
def fused_recurrent_rwkv6(
r: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
scale: int = -1,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
causal: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
r (torch.Tensor):
reception of shape `(B, H, T, K)`. Alias: q, query in linear attention.
k (torch.Tensor):
keys of shape `(B, H, T, K)`
v (torch.Tensor):
values of shape `(B, H, T, V)`
w (torch.Tensor):
data-dependent decays of shape `(B, H, T, K)` in log space! Alias: g.
u (torch.Tensor):
bonus of shape `(H, K)`
scale (Optional[int]):
Scale factor for the RWKV6 attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `(B, H, K, V)`. Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
"""
if scale == -1:
scale = r.shape[-1] ** -0.5
if initial_state is not None:
initial_state = initial_state.detach()
o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)
return o, final_state
這里再關注下Triton實現(xiàn)的Kernel的線程網(wǎng)格設置相關代碼,也就是FusedRecurrentRWKV6Function的forward函數(shù):
class FusedRecurrentRWKV6Function(torch.autograd.Function):
@staticmethod
@contiguous
@custom_fwd
def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):
# alias
q = r
batch_size, n_heads, seq_len, d_head_qk = q.shape
d_head_v = v.shape[-1]
# default scale
if scale is None:
scale = d_head_qk ** -0.5
BK, BV = min(triton.next_power_of_2(d_head_qk), 32), min(triton.next_power_of_2(d_head_v), 32)
NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)
num_stages = 1
num_warps = 1
o = q.new_empty(NK, batch_size, n_heads, seq_len,
d_head_v, dtype=torch.float32)
if output_final_state:
final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)
else:
final_state = None
grid = (NV, NK, batch_size * n_heads)
fused_recurrent_rwkv6_fwd_kernel[grid](
q, k, v, w, u, o, initial_state, final_state,
q.stride(1), q.stride(2), q.stride(3),
v.stride(1), v.stride(2), v.stride(3),
batch_size, n_heads, seq_len, scale,
DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,
USE_INITIAL_STATE=initial_state is not None,
STORE_FINAL_STATE=final_state is not None,
REVERSE=reverse,
num_warps=num_warps,
num_stages=num_stages
)
o = o.sum(0)
ctx.save_for_backward(q, k, v, w, u, initial_state, o)
ctx.scale = scale
ctx.reverse = reverse
# we do not need the gradient of the final state from the next chunk
# similiar to Trunctated BPTT
if final_state is not None:
final_state = final_state.detach()
return o.to(q.dtype), final_state
根據(jù)提供的輸入形狀,我們可以推導出以下參數(shù):
-
B(batch size)= 4 -
H(number of heads)= 4 -
L(sequence length)= 1024 -
D(head dimension)= 100
我們可以使用這些參數(shù)來計算 BK 和 BV 的值,以及 NK 和 NV 的值:
-
BK=min(triton.next_power_of_2(D), 32)=min(128, 32)=32 -
BV=min(triton.next_power_of_2(D, 32)=min(200, 32)=32 -
NK=triton.cdiv(D, BK)=triton.cdiv(100, 32)=4 -
NV=triton.cdiv(D, BV)=triton.cdiv(100, 32)=4
根據(jù)這些值,我們可以推導出 grid 的大小。根據(jù)代碼中的定義,grid 是一個三元組,表示 Triton Kernel 的線程網(wǎng)格大小,其中包括 (NV, NK, batch_size * n_heads)。
在這個例子中,batch_size * n_heads = 4 * 4 = 16。因此,grid 的大小將是 (4, 4, 16),相當于有256個Block在并行計算,而每個Block的內(nèi)部目前Triton的Kernel中指定的是1個warp也就是32個進程來計算。
而在RWKV-CUDA的實現(xiàn)中,對于這個case一共會使用16個線程塊,然后每個線程塊使用100個線程,從直覺上看這就是一個很不好的配置,Block數(shù)太小無法用滿SM。
Triton的kernel后續(xù)在接著學習和分析,我也需要認真學習下triton。
0x6. 總結
關于flash-linear-attention中rwkv6加速算子的實現(xiàn)后面再解析吧,后續(xù)如果RWKV6的Linear Attention算子優(yōu)化在開源社區(qū)有新的進展,我也會及時跟進和分享給大家。
