PyTorch數(shù)據(jù)Pipeline標(biāo)準(zhǔn)化代碼模板
點(diǎn)擊上方“小白學(xué)視覺(jué)”,選擇加"星標(biāo)"或“置頂”
重磅干貨,第一時(shí)間送達(dá)
PyTorch作為一款流行深度學(xué)習(xí)框架其熱度大有超越TensorFlow的感覺(jué)。根據(jù)此前的統(tǒng)計(jì),目前TensorFlow雖然仍然占據(jù)著工業(yè)界,但PyTorch在視覺(jué)和NLP領(lǐng)域的頂級(jí)會(huì)議上已呈一統(tǒng)之勢(shì)。
這篇文章筆者將和大家聚焦于PyTorch的自定義數(shù)據(jù)讀取pipeline模板和相關(guān)trciks以及如何優(yōu)化數(shù)據(jù)讀取的pipeline等。我們從PyTorch的數(shù)據(jù)對(duì)象類Dataset開(kāi)始。Dataset在PyTorch中的模塊位于utils.data下。
from torch.utils.data import Dataset本文將圍繞Dataset對(duì)象分別從原始模板、torchvision的transforms模塊、使用pandas來(lái)輔助讀取、torch內(nèi)置數(shù)據(jù)劃分功能和DataLoader來(lái)展開(kāi)闡述。
Dataset原始模板
PyTorch官方為我們提供了自定義數(shù)據(jù)讀取的標(biāo)準(zhǔn)化代碼代碼模塊,作為一個(gè)讀取框架,我們這里稱之為原始模板。其代碼結(jié)構(gòu)如下:
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, ...):# stuffdef __getitem__(self, index):# stuffreturn (img, label)def __len__(self):# return examples sizereturn count
根據(jù)這個(gè)標(biāo)準(zhǔn)化的代碼模板,我們只需要根據(jù)自己的數(shù)據(jù)讀取任務(wù),分別往__init__()、__getitem__()和__len__()三個(gè)方法里添加讀取邏輯即可。作為PyTorch范式下的數(shù)據(jù)讀取以及為了后續(xù)的data loader,三個(gè)方法缺一不可。其中:
__init__()函數(shù)用于初始化數(shù)據(jù)讀取邏輯,比如讀取包含標(biāo)簽和圖片地址的csv文件、定義transform組合等。
__getitem__()函數(shù)用來(lái)返回?cái)?shù)據(jù)和標(biāo)簽。目的上是為了能夠被后續(xù)的dataloader所調(diào)用。
__len__()函數(shù)則用于返回樣本數(shù)量。
現(xiàn)在我們往這個(gè)框架里填幾行代碼來(lái)形成一個(gè)簡(jiǎn)單的數(shù)字案例。創(chuàng)建一個(gè)從1到100的數(shù)字例子:
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self):self.samples = list(range(1, 101))def __len__(self):return len(self.samples)def __getitem__(self, idx):return self.samples[idx]if __name__ == '__main__':dataset = CustomDataset()print(len(dataset))print(dataset[50])print(dataset[1:100])

添加torchvision.transforms
然后我們來(lái)看如何從內(nèi)存中讀取數(shù)據(jù)以及如何在讀取過(guò)程中嵌入torchvision中的transforms功能。torchvision是一個(gè)獨(dú)立于torch的關(guān)于數(shù)據(jù)、模型和一些圖像增強(qiáng)操作的輔助庫(kù)。主要包括datasets默認(rèn)數(shù)據(jù)集模塊、models經(jīng)典模型模塊、transforms圖像增強(qiáng)模塊以及utils模塊等。在使用torch讀取數(shù)據(jù)的時(shí)候,一般會(huì)搭配上transforms模塊對(duì)數(shù)據(jù)進(jìn)行一些處理和增強(qiáng)工作。
添加了tranforms之后的讀取模塊可以改寫為:
from torch.utils.data import Datasetfrom torchvision import transforms as Tclass CustomDataset(Dataset):def __init__(self, ...):# stuff...# compose the transforms methodsself.transform = T.Compose([T.CenterCrop(100),T.ToTensor()])def __getitem__(self, index):# stuff...data = # Some data read from a file or image# execute the transformdata = self.transform(data)return (img, label)def __len__(self):# return examples sizereturn countif __name__ == '__main__':# Call the datasetcustom_dataset = CustomDataset(...)
可以看到,我們使用了Compose方法來(lái)把各種數(shù)據(jù)處理方法聚合到一起進(jìn)行定義數(shù)據(jù)轉(zhuǎn)換方法。通常作為初始化方法放在__init__()函數(shù)下。我們以貓狗圖像數(shù)據(jù)為例進(jìn)行說(shuō)明。

定義數(shù)據(jù)讀取方法如下:
class DogCat(Dataset):def __init__(self, root, transforms=None, train=True, val=False):"""get images and execute transforms."""self.val = valimgs = [os.path.join(root, img) for img in os.listdir(root)]# train: Cats_Dogs/trainset/cat.1.jpg# val: Cats_Dogs/valset/cat.10004.jpgimgs = sorted(imgs, key=lambda x: x.split('.')[-2])self.imgs = imgsif transforms is None:# normalizenormalize = T.Normalize(mean = [0.485, 0.456, 0.406],std = [0.229, 0.224, 0.225])# trainset and valset have different data transform# trainset need data augmentation but valset don't.# valsetif self.val:self.transforms = T.Compose([T.Resize(224),T.CenterCrop(224),T.ToTensor(),normalize])# trainsetelse:self.transforms = T.Compose([T.Resize(256),T.RandomResizedCrop(224),T.RandomHorizontalFlip(),T.ToTensor(),normalize])def __getitem__(self, index):"""return data and label"""img_path = self.imgs[index]label = 1 if 'dog' in img_path.split('/')[-1] else 0data = Image.open(img_path)data = self.transforms(data)return data, labeldef __len__(self):"""return images size."""return len(self.imgs)if __name__ == "__main__":train_dataset = DogCat('./Cats_Dogs/trainset/', train=True)print(len(train_dataset))print(train_dataset[0])
因?yàn)檫@個(gè)數(shù)據(jù)集已經(jīng)分好了訓(xùn)練集和驗(yàn)證集,所以在讀取和transforms的時(shí)候需要進(jìn)行區(qū)分。運(yùn)行示例如下:

與pandas一起使用
很多時(shí)候數(shù)據(jù)的目錄地址和標(biāo)簽都是通過(guò)csv文件給出的。如下所示:

此時(shí)在數(shù)據(jù)讀取的pipeline中我們需要在__init__()方法中利用pandas把csv文件中包含的圖片地址和標(biāo)簽融合進(jìn)去。相應(yīng)的數(shù)據(jù)讀取pipeline模板可以改寫為:
class CustomDatasetFromCSV(Dataset):def __init__(self, csv_path):"""Args:csv_path (string): path to csv filetransform: pytorch transforms for transforms and tensor conversion"""# Transformsself.to_tensor = transforms.ToTensor()# Read the csv fileself.data_info = pd.read_csv(csv_path, header=None)# First column contains the image pathsself.image_arr = np.asarray(self.data_info.iloc[:, 0])# Second column is the labelsself.label_arr = np.asarray(self.data_info.iloc[:, 1])# Calculate lenself.data_len = len(self.data_info.index)def __getitem__(self, index):# Get image name from the pandas dfsingle_image_name = self.image_arr[index]# Open imageimg_as_img = Image.open(single_image_name)# Transform image to tensorimg_as_tensor = self.to_tensor(img_as_img)# Get label of the image based on the cropped pandas columnsingle_image_label = self.label_arr[index]return (img_as_tensor, single_image_label)def __len__(self):return self.data_lenif __name__ == "__main__":# Call datasetdataset = CustomDatasetFromCSV('./labels.csv')
以mnist_label.csv文件為示例:
from torch.utils.data import Datasetfrom torch.utils.data import DataLoaderfrom torchvision import transforms as Tfrom PIL import Imageimport osimport numpy as npimport pandas as pdclass CustomDatasetFromCSV(Dataset):def __init__(self, csv_path):"""Args:csv_path (string): path to csv filetransform: pytorch transforms for transforms and tensor conversion"""# Transformsself.to_tensor = T.ToTensor()# Read the csv fileself.data_info = pd.read_csv(csv_path, header=None)# First column contains the image pathsself.image_arr = np.asarray(self.data_info.iloc[:, 0])# Second column is the labelsself.label_arr = np.asarray(self.data_info.iloc[:, 1])# Third column is for an operation indicatorself.operation_arr = np.asarray(self.data_info.iloc[:, 2])# Calculate lenself.data_len = len(self.data_info.index)def __getitem__(self, index):# Get image name from the pandas dfsingle_image_name = self.image_arr[index]# Open imageimg_as_img = Image.open(single_image_name)# Check if there is an operationsome_operation = self.operation_arr[index]# If there is an operationif some_operation:# Do some operation on image# ...# ...pass# Transform image to tensorimg_as_tensor = self.to_tensor(img_as_img)# Get label of the image based on the cropped pandas columnsingle_image_label = self.label_arr[index]return (img_as_tensor, single_image_label)def __len__(self):return self.data_lenif __name__ == "__main__":transform = T.Compose([T.ToTensor()])dataset = CustomDatasetFromCSV('./mnist_labels.csv')print(len(dataset))print(dataset[5])
運(yùn)行示例如下:

訓(xùn)練集驗(yàn)證集劃分
一般來(lái)說(shuō),為了模型訓(xùn)練的穩(wěn)定,我們需要對(duì)數(shù)據(jù)劃分訓(xùn)練集和驗(yàn)證集。torch的Dataset對(duì)象也提供了random_split函數(shù)作為數(shù)據(jù)劃分工具,且劃分結(jié)果可直接供后續(xù)的DataLoader使用。
以kaggle的花朵數(shù)據(jù)為例:
from torch.utils.data import DataLoaderfrom torchvision.datasets import ImageFolderfrom torchvision import transforms as Tfrom torch.utils.data import random_splittransform = T.Compose([T.Resize((224, 224)),T.RandomHorizontalFlip(),T.ToTensor()])dataset = ImageFolder('./flowers_photos', transform=transform)print(dataset.class_to_idx)trainset, valset = random_split(dataset,[int(len(dataset)*0.7), len(dataset)-int(len(dataset)*0.7)])trainloader = DataLoader(dataset=trainset, batch_size=32, shuffle=True, num_workers=1)for i, (img, label) in enumerate(trainloader):img, label = img.numpy(), label.numpy()print(img, label)valloader = DataLoader(dataset=valset, batch_size=32, shuffle=True, num_workers=1)for i, (img, label) in enumerate(trainloader):img, label = img.numpy(), label.numpy()print(img.shape, label)
這里使用了ImageFolder模塊,可以直接讀取各標(biāo)簽對(duì)應(yīng)的文件夾,部分運(yùn)行示例如下:

使用DataLoader
dataset方法寫好之后,我們還需要使用DataLoader將其逐個(gè)喂給模型。上一節(jié)的數(shù)據(jù)劃分我們已經(jīng)用到了DataLoader函數(shù)。從本質(zhì)上來(lái)講,DataLoader只是調(diào)用了__getitem__()方法并按批次返回?cái)?shù)據(jù)和標(biāo)簽。使用方法如下:
from torch.utils.data import DataLoaderfrom torchvision import transforms as Tif __name__ == "__main__":# Define transformstransformations = T.Compose([T.ToTensor()])# Define custom datasetdataset = CustomDatasetFromCSV('./labels.csv')# Define data loaderdata_loader = DataLoader(dataset=dataset, batch_size=10, shuffle=True)for images, labels in data_loader:# Feed the data to the model
以上就是PyTorch讀取數(shù)據(jù)的Pipeline主要方法和流程。基于Dataset對(duì)象的基本框架不變,具體細(xì)節(jié)可自定義化調(diào)整。
好消息!
小白學(xué)視覺(jué)知識(shí)星球
開(kāi)始面向外開(kāi)放啦??????
下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺(jué)、目標(biāo)跟蹤、生物視覺(jué)、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目52講 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):Python視覺(jué)實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測(cè)、車道線檢測(cè)、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測(cè)、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺(jué)實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺(jué)。 下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講 在「小白學(xué)視覺(jué)」公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。 交流群
歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺(jué)、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測(cè)、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競(jìng)賽等微信群(以后會(huì)逐漸細(xì)分),請(qǐng)掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺(jué)SLAM“。請(qǐng)按照格式備注,否則不予通過(guò)。添加成功后會(huì)根據(jù)研究方向邀請(qǐng)進(jìn)入相關(guān)微信群。請(qǐng)勿在群內(nèi)發(fā)送廣告,否則會(huì)請(qǐng)出群,謝謝理解~

