1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        挑戰(zhàn)Transformer!Mamba的架構(gòu)及實現(xiàn)(Pytorch)

        共 25939字,需瀏覽 52分鐘

         ·

        2024-04-10 14:53

        Mamba一經(jīng)出現(xiàn)就在人工智能界掀起波瀾,被吹捧為Transformer的競爭對手。 到底是什么讓Mamba在擁擠的序列建模中脫穎而出?  今天我們來詳細研究這篇論文《Mamba:具有選擇性狀態(tài)空間的線性時間序列建模


        e0e036862f6cd5ce85d71fd3c1225313.webp

        在介紹之前先簡要回顧一下現(xiàn)有的模型

        6429540eceefc8c0222a43db3b9a6ce0.webp

        Transformer:以其注意力機制而聞名,其中序列的任何部分都可以動態(tài)地與任何其他部分相互作用,特別是具有因果注意力機制的的Transformer,擅長處理序列中的單個元素。但是它們帶來了顯著的計算和內(nèi)存成本,與序列長度的平方(L2)成比例。

        循環(huán)神經(jīng)網(wǎng)絡(rnn): rnn只考慮當前輸入和最后一個隱藏狀態(tài),按順序更新隱藏狀態(tài)。這種方法允許它們潛在地處理無限序列長度和恒定的內(nèi)存需求。但是rnn的簡單性是一個缺點,限制了它們記住長期依賴關系的能力。此外,rnn中的時間反向傳播(BPTT)是內(nèi)存密集型的,并且可能遭受梯度消失或爆炸的影響,盡管有LSTM等創(chuàng)新部分結(jié)解決了這個問題。

        State Space Models(S4):這些模型已經(jīng)顯示出很好的特性。它們提供了一種平衡,比rnn更有效地捕獲遠程依賴關系,同時比transformer更高效地使用內(nèi)存。

        接下來Manba登場!

        Mamba

        選擇性狀態(tài)空間:Mamba建立在狀態(tài)空間模型的概念之上,但引入了一個新的變化。它利用選擇性狀態(tài)空間,支持跨長序列更高效和有效地捕獲相關信息。

        線性時間復雜度:與Transformer不同,Mamba在序列長度方面以線性時間運行。這個屬性使得它特別適合涉及非常長的序列的任務,而傳統(tǒng)模型在這方面會遇到困難。

        9f6a35b129ac808d5fd885c193ced12b.webp

        Mamba以其選擇性狀態(tài)空間的概念引入了傳統(tǒng)狀態(tài)空間模型的一個有趣的改進。這種方法稍微放松了標準狀態(tài)空間模型的嚴格狀態(tài)轉(zhuǎn)換,使其更具適應性和靈活性(有點類似于lstm)。并且Mamba保留了狀態(tài)空間模型的高效計算特性,使其能夠在一次掃描中執(zhí)行整個序列的前向傳遞-這一特性更讓人想起Transformer。

        在訓練期間,Mamba的行為類似于Transformer,同時處理整個序列。而lstm必須一步一步地計算前向傳遞,即使所有輸入都是已知的。在推理中,Mamba的行為更符合傳統(tǒng)的循環(huán)模型,提供有效的序列處理。

        先驗狀態(tài)空間模型(ssm)的一個關鍵限制是其剛性的、輸入不變的結(jié)構(gòu)。這些模型為整個序列使用一組固定參數(shù)(我們稱它們?yōu)閍和B)。這種結(jié)構(gòu)甚至比lstm等模型更具限制性,在lstm中,信號的轉(zhuǎn)換可能依賴于先前的隱藏狀態(tài)和輸入。

        Mamba則一種范式轉(zhuǎn)換,即如何計算向下一個隱藏狀態(tài)的過渡?在Mamba的體系結(jié)構(gòu)中,轉(zhuǎn)換依賴于當前輸入,這種方法在傳統(tǒng)ssm的固定計算和循環(huán)神經(jīng)網(wǎng)絡的輸入依賴動態(tài)性之間取得了平衡。

        主要組成如下:

        固定主干:從一個隱藏狀態(tài)到下一個隱藏狀態(tài)的轉(zhuǎn)換仍然是一個固定的計算(由a矩陣定義),允許跨序列的預計算。

        輸入相關轉(zhuǎn)換:輸入影響下一個隱藏狀態(tài)(由B矩陣定義)的方式取決于當前輸入,而不是之前的隱藏狀態(tài)。與傳統(tǒng)ssm相比,這種輸入依賴性提供了更大的靈活性。

        7e087b8702e178f317dd35594a3b0180.webp

        為了滿足這種方法的計算需求,Mamba使用了一種硬件感知算法。該算法使用掃描操作而不是卷積來循環(huán)執(zhí)行計算,這樣在gpu上非常高效的。盡管輸入依賴轉(zhuǎn)換帶來了算法復雜性,但這種效率對于保持高性能至關重要。

        Mamba和選擇性狀態(tài)空間模型不是同義詞。Mamba是一個使用選擇性狀態(tài)空間概念的實現(xiàn)。這種區(qū)別是至關重要的,因為它突出了Mamba的獨特貢獻:在保持計算效率的同時,使SSM框架更加靈活和響應輸入。

        SRAM和HBM

        98b03bda360130e8d6a4717748667ed9.webp

        gpu包含兩種主要類型的內(nèi)存:HBM (High Bandwidth memory)和SRAM (Static Random-Access memory)。HBM雖然帶寬很高,但與更快但更小的SRAM相比,它的訪問時間相對較慢。Mamba則使用SRAM在矩陣乘法期間進行快速訪問,這是其計算的關鍵。

        計算中的主要瓶頸通常不是計算本身,而是數(shù)據(jù)在內(nèi)存類型之間的移動。Mamba通過顯著減少傳輸大量數(shù)據(jù)的需求來解決這個問題。它通過直接在SRAM中執(zhí)行算法的關鍵部分(如離散化和遞歸計算)來實現(xiàn),從而減少延遲。

        還引入了一個融合選擇掃描層,使其內(nèi)存需求與使用flash attention的優(yōu)化Transformer實現(xiàn)相當。這一層對于保持效率至關重要,尤其是在處理模型中依賴于輸入的元素時。

        結(jié)果

        5fc00a631acf0ef27c084d138d8d40f1.webp

        Mamba代表了序列建模的重大進步,特別是在其高效使用GPU內(nèi)存和計算策略方面。它具有高效率處理長序列的能力,使其成為各種應用的有前途的模型,我們下面來使用Pytorch代碼來對其進復現(xiàn)。

        Pytorch復現(xiàn)

        導入基本庫

         import torch
         import torch.nn as nn
         import torch.optim as optim
         from torch.utils.data import DataLoader, Dataset
         from torch.nn import functional as F
         from einops import rearrange
         from tqdm import tqdm
         
         import math
         import os
         import urllib.request
         from zipfile import ZipFile
         
         from transformers import AutoTokenizer
         
         torch.autograd.set_detect_anomaly(True)

        設置標志和超參數(shù)

         # Configuration flags and hyperparameters
         USE_MAMBA = 1
         DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0
         
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        定義超參數(shù)和初始化

         d_model = 8
         state_size = 128 # Example state size
         seq_len = 100 # Example sequence length
         batch_size = 256 # Example batch size
         last_batch_size = 81 # only for the very last batch of the dataset
         current_batch_size = batch_size
         different_batch_size = False
         h_new = None
         temp_buffer = None

        這里的超參數(shù),如模型維度(d_model)、狀態(tài)大小、序列長度和批大小。

        S6模塊是Mamba架構(gòu)中的一個復雜組件,負責通過一系列線性變換和離散化過程處理輸入序列。它在捕獲序列的時間動態(tài)方面起著關鍵作用,這是序列建模任務(如語言建模)的一個關鍵方面。這里包括張量運算和自定義離散化方法來處理序列數(shù)據(jù)的復雜需求。

         class S6(nn.Module):
             def __init__(self, seq_len, d_model, state_size, device):
                 super(S6, self).__init__()
         
                 self.fc1 = nn.Linear(d_model, d_model, device=device)
                 self.fc2 = nn.Linear(d_model, state_size, device=device)
                 self.fc3 = nn.Linear(d_model, state_size, device=device)
         
                 self.seq_len = seq_len
                 self.d_model = d_model
                 self.state_size = state_size
         
         
                 self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
                 nn.init.xavier_uniform_(self.A)
         
                 self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
                 self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
         
                 self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
                 self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
                 self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
         
                 # h [batch_size, seq_len, d_model, state_size]
                 self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
                 self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
         
         
             def discretization(self):
         
                 self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
         
                 self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
         
         
                 return self.dA, self.dB
         
             def forward(self, x):
                 # Algorithm 2 MAMBA paper
                 self.B = self.fc2(x)
                 self.C = self.fc3(x)
                 self.delta = F.softplus(self.fc1(x))
         
                 self.discretization()
         
                 if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:  
                   
                     global current_batch_size
                     current_batch_size = x.shape[0]
         
                     if self.h.shape[0] != current_batch_size:
                         different_batch_size = True
         
                         h_new =  torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB
         
                     else:
                         different_batch_size = False
                         h_new =  torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB
         
                     # y [batch_size, seq_len, d_model]
                     self.y = torch.einsum('bln,bldn->bld', self.C, h_new)
         
                     global temp_buffer
                     temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()
           
                     return self.y
         
                 else:  
                     # h [batch_size, seq_len, d_model, state_size]
                     h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
                     y = torch.zeros_like(x)
         
                     h =  torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB
         
                     # y [batch_size, seq_len, d_model]
                     y = torch.einsum('bln,bldn->bld', self.C, h)
         
                     return y

        這個S6的模塊,可以處理離散化過程和正向傳播。

        MambaBlock類是一個定制的神經(jīng)網(wǎng)絡模塊,被設計為Mamba模型的關鍵構(gòu)建塊。它封裝了幾個層和操作來處理輸入數(shù)據(jù)。

        包括線性投影、卷積、激活函數(shù)、自定義S6模塊和殘差連接。該塊是Mamba模型的基本組件,負責通過一系列轉(zhuǎn)換處理輸入序列,以捕獲數(shù)據(jù)中的相關模式和特征。這些不同層和操作的組合允許MambaBlock有效地處理復雜的序列建模任務。MambaBlock是Mamba核心功能。

         class MambaBlock(nn.Module):
             def __init__(self, seq_len, d_model, state_size, device):
                 super(MambaBlock, self).__init__()
         
                 self.inp_proj = nn.Linear(d_model, 2*d_model, device=device)
                 self.out_proj = nn.Linear(2*d_model, d_model, device=device)
         
                 # For residual skip connection
                 self.D = nn.Linear(d_model, 2*d_model, device=device)
         
                 # Set _no_weight_decay attribute on bias
                 self.out_proj.bias._no_weight_decay = True
         
                 # Initialize bias to a small constant value
                 nn.init.constant_(self.out_proj.bias, 1.0)
         
                 self.S6 = S6(seq_len, 2*d_model, state_size, device)
         
                 # Add 1D convolution with kernel size 3
                 self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)
         
                 # Add linear layer for conv output
                 self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)
         
                 # rmsnorm
                 self.norm = RMSNorm(d_model, device=device)
         
             def forward(self, x):
                 """
                x_proj.shape = torch.Size([batch_size, seq_len, 2*d_model])
                x_conv.shape = torch.Size([batch_size, seq_len, 2*d_model])
                x_conv_act.shape = torch.Size([batch_size, seq_len, 2*d_model])
                """
                 # Refer to Figure 3 in the MAMBA paper
         
                 x = self.norm(x)
         
                 x_proj = self.inp_proj(x)
         
                 # Add 1D convolution with kernel size 3
                 x_conv = self.conv(x_proj)
         
                 x_conv_act = F.silu(x_conv)
         
                 # Add linear layer for conv output
                 x_conv_out = self.conv_linear(x_conv_act)
         
                 x_ssm = self.S6(x_conv_out)
                 x_act = F.silu(x_ssm)  # Swish activation can be implemented as x * sigmoid(x)
         
                 # residual skip connection with nonlinearity introduced by multiplication
                 x_residual = F.silu(self.D(x))
         
                 x_combined = x_act * x_residual
         
                 x_out = self.out_proj(x_combined)
         
                 return x_out


        Mamba模型

        包括一系列MambaBlock模塊。每個塊都順序處理輸入數(shù)據(jù),一個塊的輸出作為下一個塊的輸入。這種順序處理允許模型捕獲輸入數(shù)據(jù)中的復雜模式和關系,使其對涉及順序建模的任務有效。多個塊的堆疊是深度學習架構(gòu)中的常見設計,因為它使模型能夠?qū)W習數(shù)據(jù)的分層表示。

         class Mamba(nn.Module):
             def __init__(self, seq_len, d_model, state_size, device):
                 super(Mamba, self).__init__()
                 self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
                 self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
                 self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)
         
             def forward(self, x):
                 x = self.mamba_block1(x)
                 x = self.mamba_block2(x)
                 x = self.mamba_block3(x)
                 return x

        RMSNorm是一個自定義規(guī)范化層,這一層用于規(guī)范神經(jīng)網(wǎng)絡的激活,這可以幫助穩(wěn)定和加快訓練。

         class RMSNorm(nn.Module):
             def __init__(self,
                          d_model: int,
                          eps: float = 1e-5,
                          device: str ='cuda'):
                 super().__init__()
                 self.eps = eps
                 self.weight = nn.Parameter(torch.ones(d_model, device=device))
         
         
             def forward(self, x):
                 output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
         
                 return output

        這一層的用法:

         x = torch.rand(batch_size, seq_len, d_model, device=device)
         # Create the Mamba model
         mamba = Mamba(seq_len, d_model, state_size, device)
         
         # rmsnorm
         norm = RMSNorm(d_model)
         x = norm(x)
         
         # Forward pass
         test_output = mamba(x)
         print(f"test_output.shape = {test_output.shape}")  # Should be [batch_size, seq_len, d_model]

        上面就是模型的全部基本代碼,下面就可以進行數(shù)據(jù)準備和訓練

        我們自定義一個Enwiki8Dataset

         class Enwiki8Dataset(Dataset):
             def __init__(self, data):
                 self.data = data
         
             def __len__(self):
                 return len(self.data['input_ids'])
         
             def __getitem__(self, idx):
                 item = {key: val[idx].clone().detach() for key, val in self.data.items()}
                 return item

        pad_sequences_3d用于將一批序列填充到統(tǒng)一的長度,確保批中的每個序列具有相同數(shù)量的元素(或時間步長)。這在許多機器學習任務中尤其重要,因為輸入數(shù)據(jù)必須具有一致的形狀。

         # Define a function for padding
         def pad_sequences_3d(sequences, max_len=None, pad_value=0):
             # Assuming sequences is a tensor of shape (batch_size, seq_len, feature_size)
             batch_size, seq_len, feature_size = sequences.shape
         
             if max_len is None:
                 max_len = seq_len + 1
         
         
             # Initialize padded_sequences with the pad_value
             padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)
             # Pad each sequence to the max_len
             padded_sequences[:, :seq_len, :] = sequences
         
             return padded_sequences

        訓練過程:

         def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
             model.train()
             total_loss = 0
             for batch in data_loader:
                 optimizer.zero_grad()
         
                 input_data = batch['input_ids'].clone().to(device)
                 attention_mask = batch['attention_mask'].clone().to(device)
         
                 target = input_data[:, 1:]
                 input_data = input_data[:, :-1]
         
                 # Pad all the sequences in the batch:
                 input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
                 target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)
         
                 if USE_MAMBA:
                     output = model(input_data)
                     loss = criterion(output, target)
         
                 loss.backward(retain_graph=True)
         
                 for name, param in model.named_parameters():
                    if 'out_proj.bias' not in name:
                        # clip weights but not bias for out_proj
                        torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)
         
                 if DEBUGGING_IS_ON:
                     for name, parameter in model.named_parameters():
                         if parameter.grad is not None:
                             print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                         else:
                             print(f"{name} has no gradient")
         
                 if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
                     model.S6.h[:current_batch_size, ...].copy_(temp_buffer)
         
                 optimizer.step()
         
                 total_loss += loss.item()
             return total_loss / len(data_loader)

        評估函數(shù):

         def evaluate(model, data_loader, criterion, device):
             model.eval()
             total_loss = 0
             with torch.no_grad():
                 for batch in data_loader:
                     input_data = batch['input_ids'].clone().detach().to(device)
                     attention_mask = batch['attention_mask'].clone().detach().to(device)
         
                     target = input_data[:, 1:]
                     input_data = input_data[:, :-1]
         
                     # Pad all the sequences in the batch:
                     input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
                     target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)
         
                     if USE_MAMBA:
                         output = model(input_data)
                         loss = criterion(output, target)
                     total_loss += loss.item()
             return total_loss / len(data_loader)

        最后,calculate_perplexity用于評估語言模型(如Mamba)的性能。

         def calculate_perplexity(loss):
             return math.exp(loss)

        load_enwiki8_dataset函數(shù)用于下載和提取enwiki8數(shù)據(jù)集,該數(shù)據(jù)集通常用于對語言模型進行基準測試。

         def load_enwiki8_dataset():
             print(f"Download and extract enwiki8 data")
             url = "http://mattmahoney.net/dc/enwik8.zip"
             urllib.request.urlretrieve(url, "enwik8.zip")
         
             with ZipFile("enwik8.zip") as f:
                 data = f.read("enwik8").decode("utf-8")
         
             return data

        encode_dataset函數(shù)設計用于標記和編碼數(shù)據(jù)集,為神經(jīng)網(wǎng)絡模型(如Mamba)處理數(shù)據(jù)集做準備。

         # Tokenize and encode the dataset
         def encode_dataset(tokenizer, text_data):
             def batch_encode(tokenizer, text_data, batch_size=1000):
                 # Tokenize in batches
                 batched_input_ids = []
                 for i in range(0, len(text_data), batch_size):
                     batch = text_data[i:i+batch_size]
                     inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                                        padding='max_length', max_length=seq_len,
                                        return_tensors='pt')
                     batched_input_ids.append(inputs['input_ids'])
                 return torch.cat(batched_input_ids)
         
             # Assuming enwiki8_data is a list of sentences
             input_ids = batch_encode(tokenizer, enwiki8_data)
         
             # vocab_size is the number of unique tokens in the tokenizer's vocabulary
             global vocab_size
             vocab_size = len(tokenizer.vocab)  # Note that for some tokenizers, we might access the vocab directly
             print(f"vocab_size = {vocab_size}")
         
             # Create an embedding layer
             # embedding_dim is the size of the embedding vectors (MAMBA model's D)
             embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
         
             # Pass `input_ids` through the embedding layer
             # This will change `input_ids` from shape [B, L] to [B, L, D]
             def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):
                 # Check if input_ids is already a tensor, if not convert it
                 if not isinstance(input_ids, torch.Tensor):
                     input_ids = torch.tensor(input_ids, dtype=torch.long)
         
                 # Calculate the number of batches needed
                 num_batches = math.ceil(input_ids.size(0) / batch_size)
         
                 # List to hold the output embeddings
                 output_embeddings = []
         
                 # Process each batch
                 for i in range(num_batches):
                     # Calculate start and end indices for the current batch
                     start_idx = i * batch_size
                     end_idx = start_idx + batch_size
         
                     # Get the batch
                     input_id_batch = input_ids[start_idx:end_idx]
         
                     # Call the embedding layer
                     with torch.no_grad():  # No need gradients for this operation
                         batch_embeddings = embedding_layer(input_id_batch)
         
                     # Append the result to the list
                     output_embeddings.append(batch_embeddings)
         
                 # Concatenate the embeddings from each batch into a single tensor
                 all_embeddings = torch.cat(output_embeddings, dim=0)
         
                 return all_embeddings
         
             # `input_ids` is a list or tensor of the input IDs and `embedding_layer` is model's embedding layer
             if USE_MAMBA:
                 # Set `batch_size` to a value that works for memory constraints
                 encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=1).float()
         
             attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)
         
             return encoded_inputs, attention_mask

        下面就可以進行訓練了

         # Load a pretrained tokenizer
         tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
         
         # Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]
         encoded_inputs_file = 'encoded_inputs_mamba.pt'
         
         
         if os.path.exists(encoded_inputs_file):
             print("Loading pre-tokenized data...")
             encoded_inputs = torch.load(encoded_inputs_file)
         else:
             print("Tokenizing raw data...")
             enwiki8_data = load_enwiki8_dataset()
             encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
             torch.save(encoded_inputs, encoded_inputs_file)
             print(f"finished tokenizing data")
         
         
         # Combine into a single dictionary
         data = {
             'input_ids': encoded_inputs,
             'attention_mask': attention_mask
         }
         
         # Split the data into train and validation sets
         total_size = len(data['input_ids'])
         train_size = int(total_size * 0.8)
         
         train_data = {key: val[:train_size] for key, val in data.items()}
         val_data = {key: val[train_size:] for key, val in data.items()}
         
         train_dataset = Enwiki8Dataset(train_data)
         val_dataset = Enwiki8Dataset(val_data)
         
         
         train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
         val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
         
         
         # Initialize the model
         
         model = Mamba(seq_len, d_model, state_size, device).to(device)
         
         # Define the loss function and optimizer
         criterion = nn.CrossEntropyLoss()
         optimizer = optim.AdamW(model.parameters(), lr=5e-6)
         
         # Training loop
         num_epochs = 25  # Number of epochs to train for
         
         for epoch in tqdm(range(num_epochs)):  # loop over the dataset multiple times
             train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=False)
             val_loss = evaluate(model, val_loader, criterion, device)
             val_perplexity = calculate_perplexity(val_loss)
             print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

        以上就是訓練的完整代碼

        總結(jié)

        我們介紹了Mamba的概念和架構(gòu),并且從頭開始構(gòu)建Mamba復現(xiàn),這樣可以將理論轉(zhuǎn)化為實踐。通過這種動手的方法,可以看到Mamba序列建模方法和效率。如果你想直接使用,可以看論文提供的代碼。

        論文地址:

        https://arxiv.org/abs/2312.00752

        論文提供的源代碼:

        https://github.com/state-spaces/mamba

        瀏覽 133
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            日韩一级黄色视频 | 乱色综合| 亚洲精品久久久久玩吗 | 在线观看亚洲AV无码 | 中日韩欧美在线 | 日韩久久不卡 | 掀开白丝袜jk裙子扒掉内裤 | 久久久久久看片 | 国产激情自拍 | 淫乱91 |