1. 利用PyTorch訓(xùn)練一個(gè)CNN分類器

        共 12117字,需瀏覽 25分鐘

         ·

        2022-12-30 21:48

        點(diǎn)擊上方小白學(xué)視覺”,選擇加"星標(biāo)"或“置頂

        重磅干貨,第一時(shí)間送達(dá)

        前言

        原文翻譯自:Deep Learning with PyTorch: A 60 Minute Blitz

        翻譯:林不清(https://www.zhihu.com/people/lu-guo-92-42-88)

        目錄

        訓(xùn)練一個(gè)分類器

        你已經(jīng)學(xué)會(huì)如何去定義一個(gè)神經(jīng)網(wǎng)絡(luò),計(jì)算損失值和更新網(wǎng)絡(luò)的權(quán)重。

        你現(xiàn)在可能在思考:數(shù)據(jù)哪里來呢?

        關(guān)于數(shù)據(jù)

        通常,當(dāng)你處理圖像,文本,音頻和視頻數(shù)據(jù)時(shí),你可以使用標(biāo)準(zhǔn)的Python包來加載數(shù)據(jù)到一個(gè)numpy數(shù)組中.然后把這個(gè)數(shù)組轉(zhuǎn)換成torch.*Tensor。

        • 對于圖像,有諸如Pillow,OpenCV包等非常實(shí)用
        • 對于音頻,有諸如scipy和librosa包
        • 對于文本,可以用原始Python和Cython來加載,或者使用NLTK和SpaCy          對于視覺,我們創(chuàng)建了一個(gè)torchvision包,包含常見數(shù)據(jù)集的數(shù)據(jù)加載,比如Imagenet,CIFAR10,MNIST等,和圖像轉(zhuǎn)換器,也就是torchvision.datasetstorch.utils.data.DataLoader

        這提供了巨大的便利,也避免了代碼的重復(fù)。

        在這個(gè)教程中,我們使用CIFAR10數(shù)據(jù)集,它有如下10個(gè)類別:’airplane’,’automobile’,’bird’,’cat’,’deer’,’dog’,’frog’,’horse’,’ship’,’truck’。這個(gè)數(shù)據(jù)集中的圖像大小為3*32*32,即,3通道,32*32像素。

        訓(xùn)練一個(gè)圖像分類器

        我們將按照下列順序進(jìn)行:

        • 使用torchvision加載和歸一化CIFAR10訓(xùn)練集和測試集.
        • 定義一個(gè)卷積神經(jīng)網(wǎng)絡(luò)
        • 定義損失函數(shù)
        • 在訓(xùn)練集上訓(xùn)練網(wǎng)絡(luò)
        • 在測試集上測試網(wǎng)絡(luò)

        1. 加載和歸一化CIFAR10

        使用torchvision加載CIFAR10是非常容易的。

        %matplotlib inline
        import torch
        import torchvision
        import torchvision.transforms as transforms

        torchvision的輸出是[0,1]的PILImage圖像,我們把它轉(zhuǎn)換為歸一化范圍為[-1, 1]的張量。

        注意

        如果在Windows上運(yùn)行時(shí)出現(xiàn)BrokenPipeError,嘗試將torch.utils.data.DataLoader()的num_worker設(shè)置為0。

        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.50.50.5), (0.50.50.5))])

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                                  shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                                 shuffle=False, num_workers=2)

        classes = ('plane''car''bird''cat',
                   'deer''dog''frog''horse''ship''truck')
        #這個(gè)過程有點(diǎn)慢,會(huì)下載大約340mb圖片數(shù)據(jù)。

        我們展示一些有趣的訓(xùn)練圖像。

        import matplotlib.pyplot as plt
        import numpy as np

        # functions to show an image


        def imshow(img):
            img = img / 2 + 0.5     # unnormalize
            npimg = img.numpy()
            plt.imshow(np.transpose(npimg, (120)))
            plt.show()


        # get some random training images
        dataiter = iter(trainloader)
        images, labels = dataiter.next()

        # show images
        imshow(torchvision.utils.make_grid(images))
        # print labels
        print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

        2. 定義一個(gè)卷積神經(jīng)網(wǎng)絡(luò)

        從之前的神經(jīng)網(wǎng)絡(luò)一節(jié)復(fù)制神經(jīng)網(wǎng)絡(luò)代碼,并修改為接受3通道圖像取代之前的接受單通道圖像。

        import torch.nn as nn
        import torch.nn.functional as F


        class Net(nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.conv1 = nn.Conv2d(365)
                self.pool = nn.MaxPool2d(22)
                self.conv2 = nn.Conv2d(6165)
                self.fc1 = nn.Linear(16 * 5 * 5120)
                self.fc2 = nn.Linear(12084)
                self.fc3 = nn.Linear(8410)

            def forward(self, x):
                x = self.pool(F.relu(self.conv1(x)))
                x = self.pool(F.relu(self.conv2(x)))
                x = x.view(-116 * 5 * 5)
                x = F.relu(self.fc1(x))
                x = F.relu(self.fc2(x))
                x = self.fc3(x)
                return x


        net = Net()

        3. 定義損失函數(shù)和優(yōu)化器

        我們使用交叉熵作為損失函數(shù),使用帶動(dòng)量的隨機(jī)梯度下降。

        import torch.optim as optim

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

        4. 訓(xùn)練網(wǎng)絡(luò)

        這是開始有趣的時(shí)刻,我們只需在數(shù)據(jù)迭代器上循環(huán),把數(shù)據(jù)輸入給網(wǎng)絡(luò),并優(yōu)化。

        for epoch in range(2):  # loop over the dataset multiple times

            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:    # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0

        print('Finished Training')

        保存一下我們的訓(xùn)練模型

        PATH = './cifar_net.pth'
        torch.save(net.state_dict(), PATH)

        點(diǎn)擊這里查看關(guān)于保存模型的詳細(xì)介紹

        5. 在測試集上測試網(wǎng)絡(luò)

        我們在整個(gè)訓(xùn)練集上訓(xùn)練了兩次網(wǎng)絡(luò),但是我們還需要檢查網(wǎng)絡(luò)是否從數(shù)據(jù)集中學(xué)習(xí)到東西。

        我們通過預(yù)測神經(jīng)網(wǎng)絡(luò)輸出的類別標(biāo)簽并根據(jù)實(shí)際情況進(jìn)行檢測,如果預(yù)測正確,我們把該樣本添加到正確預(yù)測列表。

        第一步,顯示測試集中的圖片一遍熟悉圖片內(nèi)容。

        dataiter = iter(testloader)
        images, labels = dataiter.next()

        # print images
        imshow(torchvision.utils.make_grid(images))
        print('GroundTruth: '' '.join('%5s' % classes[labels[j]] for j in range(4)))

        接下來,讓我們重新加載我們保存的模型(注意:保存和重新加載模型在這里不是必要的,我們只是為了說明如何這樣做):

        net = Net()
        net.load_state_dict(torch.load(PATH))

        現(xiàn)在我們來看看神經(jīng)網(wǎng)絡(luò)認(rèn)為以上圖片是什么?

        outputs = net(images)

        輸出是10個(gè)標(biāo)簽的概率。一個(gè)類別的概率越大,神經(jīng)網(wǎng)絡(luò)越認(rèn)為他是這個(gè)類別。所以讓我們得到最高概率的標(biāo)簽。

        _, predicted = torch.max(outputs, 1)

        print('Predicted: '' '.join('%5s' % classes[predicted[j]]
                                      for j in range(4)))

        這結(jié)果看起來非常的好。

        接下來讓我們看看網(wǎng)絡(luò)在整個(gè)測試集上的結(jié)果如何。

        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))

        結(jié)果看起來好于偶然,偶然的正確率為10%,似乎網(wǎng)絡(luò)學(xué)習(xí)到了一些東西。

        那在什么類上預(yù)測較好,什么類預(yù)測結(jié)果不好呢?

        class_correct = list(0. for i in range(10))
        class_total = list(0. for i in range(10))
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                c = (predicted == labels).squeeze()
                for i in range(4):
                    label = labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1


        for i in range(10):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))

        接下來干什么?

        我們?nèi)绾卧贕PU上運(yùn)行神經(jīng)網(wǎng)絡(luò)呢?

        在GPU上訓(xùn)練

        你是如何把一個(gè)Tensor轉(zhuǎn)換GPU上,你就如何把一個(gè)神經(jīng)網(wǎng)絡(luò)移動(dòng)到GPU上訓(xùn)練。這個(gè)操作會(huì)遞歸遍歷有所模塊,并將其參數(shù)和緩沖區(qū)轉(zhuǎn)換為CUDA張量。

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Assume that we are on a CUDA machine, then this should print a CUDA device:
        #假設(shè)我們有一臺(tái)CUDA的機(jī)器,這個(gè)操作將顯示CUDA設(shè)備。
        print(device)

        接下來假設(shè)我們有一臺(tái)CUDA的機(jī)器,然后這些方法將遞歸遍歷所有模塊并將其參數(shù)和緩沖區(qū)轉(zhuǎn)換為CUDA張量:

        net.to(device)

        請記住,你也必須在每一步中把你的輸入和目標(biāo)值轉(zhuǎn)換到GPU上:

        inputs, labels = inputs.to(device), labels.to(device)

        為什么我們沒注意到GPU的速度提升很多?那是因?yàn)榫W(wǎng)絡(luò)非常的小。

        實(shí)踐:

        嘗試增加你的網(wǎng)絡(luò)的寬度(第一個(gè)nn.Conv2d的第2個(gè)參數(shù), 第二個(gè)nn.Conv2d的第一個(gè)參數(shù),他們需要是相同的數(shù)字),看看你得到了什么樣的加速。

        實(shí)現(xiàn)的目標(biāo):

        • 深入了解了PyTorch的張量庫和神經(jīng)網(wǎng)絡(luò)
        • 訓(xùn)練了一個(gè)小網(wǎng)絡(luò)來分類圖片

        在多GPU上訓(xùn)練

        如果你希望使用所有GPU來更大的加快速度,請查看選讀:[數(shù)據(jù)并行]:(https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)

        接下來做什么?

        • 訓(xùn)練神經(jīng)網(wǎng)絡(luò)玩電子游戲
        • 在ImageNet上訓(xùn)練最好的ResNet
        • 使用對抗生成網(wǎng)絡(luò)來訓(xùn)練一個(gè)人臉生成器
        • 使用LSTM網(wǎng)絡(luò)訓(xùn)練一個(gè)字符級(jí)的語言模型
        • 更多示例
        • 更多教程
        • 在論壇上討論P(yáng)yTorch
        • 在Slack上與其他用戶聊天

        好消息!

        小白學(xué)視覺知識(shí)星球

        開始面向外開放啦??????




        下載1:OpenCV-Contrib擴(kuò)展模塊中文版教程
        在「小白學(xué)視覺」公眾號(hào)后臺(tái)回復(fù):擴(kuò)展模塊中文教程即可下載全網(wǎng)第一份OpenCV擴(kuò)展模塊教程中文版,涵蓋擴(kuò)展模塊安裝、SFM算法、立體視覺、目標(biāo)跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

        下載2:Python視覺實(shí)戰(zhàn)項(xiàng)目52講
        小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):Python視覺實(shí)戰(zhàn)項(xiàng)目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計(jì)數(shù)、添加眼線、車牌識(shí)別、字符識(shí)別、情緒檢測、文本內(nèi)容提取、面部識(shí)別等31個(gè)視覺實(shí)戰(zhàn)項(xiàng)目,助力快速學(xué)校計(jì)算機(jī)視覺。

        下載3:OpenCV實(shí)戰(zhàn)項(xiàng)目20講
        小白學(xué)視覺公眾號(hào)后臺(tái)回復(fù):OpenCV實(shí)戰(zhàn)項(xiàng)目20講,即可下載含有20個(gè)基于OpenCV實(shí)現(xiàn)20個(gè)實(shí)戰(zhàn)項(xiàng)目,實(shí)現(xiàn)OpenCV學(xué)習(xí)進(jìn)階。

        交流群


        歡迎加入公眾號(hào)讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動(dòng)駕駛、計(jì)算攝影、檢測、分割、識(shí)別、醫(yī)學(xué)影像、GAN、算法競賽等微信群(以后會(huì)逐漸細(xì)分),請掃描下面微信號(hào)加群,備注:”昵稱+學(xué)校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會(huì)根據(jù)研究方向邀請進(jìn)入相關(guān)微信群。請勿在群內(nèi)發(fā)送廣告,否則會(huì)請出群,謝謝理解~


        瀏覽 45
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評論
        圖片
        表情
        推薦
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
          
          

            1. 免费成人视频 | 国产一级婬乱AⅤ大片野兰花 | 天天日天天添 | 91乱伦视频 | 久久久久久久久久久91 |