10條PyTorch避坑指南
點(diǎn)擊上方“小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)

本文轉(zhuǎn)載自:機(jī)器之心 | 作者:Eugene Khvedchenya
高性能 PyTorch 的訓(xùn)練管道是什么樣的?是產(chǎn)生最高準(zhǔn)確率的模型?是最快的運(yùn)行速度?是易于理解和擴(kuò)展?還是容易并行化?答案是,包括以上提到的所有。

建議 0:了解你代碼中的瓶頸在哪里
建議 1:如果可能的話,將數(shù)據(jù)的全部或部分移至 RAM。
class RAMDataset(Dataset):def __init__(image_fnames, targets):self.targets = targetsself.images = []for fname in tqdm(image_fnames, desc="Loading files in RAM"):with open(fname, "rb") as f:self.images.append(f.read())def __len__(self):return len(self.targets)def __getitem__(self, index):target = self.targets[index]image, retval = cv2.imdecode(self.images[index], cv2.IMREAD_COLOR)return image, target
建議 2:解析、度量、比較。每次你在管道中提出任何改變,要深入地評估它全面的影響。
# Profile CPU bottleneckspython -m cProfile training_script.py --profiling# Profile GPU bottlenecksnvprof --print-gpu-trace python train_mnist.py# Profile system calls bottlenecksstrace -fcT python training_script.py -e trace=open,close,readAdvice 3: *Preprocess everything offline*
建議 3:離線預(yù)處理所有內(nèi)容
建議 4:調(diào)整 DataLoader 的工作程序
假設(shè)我們?yōu)?Cityscapes 訓(xùn)練圖像分割模型,其批處理大小為 32,RGB 圖像大小是 512x512x3(高、寬、通道)。我們在 CPU 端進(jìn)行圖像標(biāo)準(zhǔn)化(稍后我將會(huì)解釋為什么這一點(diǎn)比較重要)。在這種情況下,我們最終的圖像 tensor 將會(huì)是 512 * 512 * 3 * sizeof(float32) = 3,145,728 字節(jié)。與批處理大小相乘,結(jié)果是 100,663,296 字節(jié),大約 100Mb;
除了圖像之外,我們還需要提供 ground-truth 掩膜。它們各自的大小為(默認(rèn)情況下,掩膜的類型是 long,8 個(gè)字節(jié))——512 * 512 * 1 * 8 * 32 = 67,108,864 或者大約 67Mb;
因此一批數(shù)據(jù)所需要的總內(nèi)存是 167Mb。假設(shè)有 8 個(gè)工作程序,內(nèi)存的總需求量將是 167 Mb * 8 = 1,336 Mb。
將 RGB 圖像保持在每個(gè)通道深度 8 位??梢暂p松地在 GPU 上將圖像轉(zhuǎn)換為浮點(diǎn)形式或者標(biāo)準(zhǔn)化。
在數(shù)據(jù)集中用 uint8 或 uint16 數(shù)據(jù)類型代替 long。
class MySegmentationDataset(Dataset):...def __getitem__(self, index):image = cv2.imread(self.images[index])target = cv2.imread(self.masks[index])# No data normalization and type casting herereturn torch.from_numpy(image).permute(2,0,1).contiguous(),torch.from_numpy(target).permute(2,0,1).contiguous()class Normalize(nn.Module):# https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/modules/normalize.pydef __init__(self, mean, std):super().__init__()self.register_buffer("mean", torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous())self.register_buffer("std", torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous())def forward(self, input: torch.Tensor) -> torch.Tensor:return (input.to(self.mean.type) - self.mean) * self.stdclass MySegmentationModel(nn.Module):def __init__(self):self.normalize = Normalize([0.221 * 255], [0.242 * 255])self.loss = nn.CrossEntropyLoss()def forward(self, image, target):image = self.normalize(image)output = self.backbone(image)if target is not None:loss = self.loss(output, target.long())return lossreturn output

model = nn.DataParallel(model) # Runs model on all available GPUsGPU 負(fù)載不平衡;
在主 GPU 上聚合需要額外的視頻內(nèi)存
在訓(xùn)練期間繼續(xù)在前向推導(dǎo)內(nèi)使用 nn.DataParallel 計(jì)算損耗。在這種情況下。za 不會(huì)將密集的預(yù)測掩碼返回給主 GPU,而只會(huì)返回單個(gè)標(biāo)量損失;
使用分布式訓(xùn)練,也稱為 nn.DistributedDataParallel。借助分布式訓(xùn)練的另一個(gè)好處是可以看到 GPU 實(shí)現(xiàn) 100% 負(fù)載。
https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255
https://medium.com/@theaccelerators/learn-pytorch-multi-gpu-properly-3eb976c030ee
https://towardsdatascience.com/how-to-scale-training-on-multiple-gpus-dae1041f49d2
建議 5: 如果你擁有兩個(gè)及以上的 GPU
def test_loss_profiling():loss = nn.BCEWithLogitsLoss()with torch.autograd.profiler.profile(use_cuda=True) as prof:input = torch.randn((8, 1, 128, 128)).cuda()input.requires_grad = Truetarget = torch.randint(1, (8, 1, 128, 128)).cuda().float()for i in range(10):l = loss(input, target)l.backward()print(prof.key_averages().table(sort_by="self_cpu_time_total"))
建議 9: 如果設(shè)計(jì)自定義模塊和損失——配置并測試他們
通過硬件升級可以更輕松地解決某些瓶頸。
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
在「小白學(xué)視覺」公眾號后臺回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。
下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺」公眾號后臺回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。
下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺」公眾號后臺回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。
交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測、分割、識別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會(huì)逐漸細(xì)分),請掃描下面微信號加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會(huì)請出群,謝謝理解~
