1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        使用PyTorch實現(xiàn)對花朵的分類

        共 5602字,需瀏覽 12分鐘

         ·

        2020-10-07 08:21


        PyTorch是一個非常適合初學(xué)者的高度可靠且強(qiáng)大的機(jī)器學(xué)習(xí)庫。自2016年10月以來,它已經(jīng)開源并由Facebook維護(hù),并被開發(fā)人員用于研究其原型,以部署最先進(jìn)的深度學(xué)習(xí)應(yīng)用程序。與TensorFlow等其他機(jī)器學(xué)習(xí)庫相比,PyTorch更加直觀,并具有實現(xiàn)模型的Python方式。

        決定要分類什么?


        識別花朵的類型需要某種形式關(guān)于花朵的知識,人必須事先看過花朵才能識別花朵。同樣,對于計算機(jī),很難對算法進(jìn)行硬編碼以識別花朵的類型。到目前為止,機(jī)器學(xué)習(xí)是從給定的大量花朵圖片中識別花朵名稱的唯一選擇。這使得使用深度學(xué)習(xí)實現(xiàn)花識別任務(wù)對于每個初學(xué)者來說都非常有趣。



        花朵識別數(shù)據(jù)集對于像我這樣的初學(xué)者而言,是一個很好的數(shù)據(jù)集,可用于實施和練習(xí)各種機(jī)器學(xué)習(xí)模型。

        使用什么數(shù)據(jù)集?


        我們將使用Kaggle上可用的花朵識別數(shù)據(jù)集。數(shù)據(jù)集鏈接:https ://www.kaggle.com/alxmamaev/flowers-recognition


        預(yù)處理數(shù)據(jù)集

        我們將使用神經(jīng)網(wǎng)絡(luò)對花朵進(jìn)行分類。神經(jīng)網(wǎng)絡(luò)是深度學(xué)習(xí)的一種形式,最適合當(dāng)今的圖像分類。我們首先導(dǎo)入所有需要的模塊以運(yùn)行我們的代碼。

        import numpy as np # linear algebraimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)import osimport torchimport torchvisionfrom torchvision.datasets.utils import download_urlfrom torch.utils.data import random_splitfrom torchvision.datasets import ImageFolderfrom torchvision import transformsfrom torchvision.transforms import ToTensorfrom torch.utils.data.dataloader import DataLoaderimport torch.nn as nnimport torch.nn.functional as F


        我們導(dǎo)入了PyTorch的組件以及NumPy和Pandas等數(shù)據(jù)科學(xué)庫。圖片是非結(jié)構(gòu)化數(shù)據(jù),為了將其輸入到我們的深度學(xué)習(xí)模型中,我們必須將其轉(zhuǎn)換為張量。我們需要對圖像進(jìn)行預(yù)處理,然后才能為模型做好準(zhǔn)備。我們首先使用ImageFolder 存在于torchvision.datasets 準(zhǔn)備數(shù)據(jù)集。ImageFolder是一個非常有用的工具當(dāng)圖像存儲在不同的文件夾中,其中每個文件夾都充當(dāng)類名。PyTorch還具有其他更簡單的準(zhǔn)備數(shù)據(jù)集的方式,我們可以在其中準(zhǔn)備自己的自定義數(shù)據(jù)集。

        transformer = torchvision.transforms.Compose(    [  # Applying Augmentation        torchvision.transforms.Resize((224, 224)),        torchvision.transforms.RandomHorizontalFlip(p=0.5),        torchvision.transforms.RandomVerticalFlip(p=0.5),        torchvision.transforms.RandomRotation(30),        torchvision.transforms.ToTensor(),        torchvision.transforms.Normalize(            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]        ),    ])dataset = ImageFolder(base_dir, transform=transformer)


        我們還習(xí)慣于transforms.Compose將圖像轉(zhuǎn)換為張量并應(yīng)用其他圖像增強(qiáng)技術(shù)。此外,在將各種圖像加載到數(shù)據(jù)集時,請閱讀各種變換技術(shù)并應(yīng)用于圖像。我們應(yīng)該使序加載圖像,以便可以每次分批添加數(shù)據(jù)集,并且可以優(yōu)化效率。


        定義模型


        我們可以使用從PyTorch類繼承的類來定義深度學(xué)習(xí)模型的框架?nn.Module.

        def accuracy(outputs, labels):    _, preds = torch.max(outputs, dim=1)    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
        class ImageClassificationModel(nn.Module): def training_step(self, batch): images, labels = batch out = self(images) # Generate predictions loss = F.cross_entropy(out, labels) # Calculate loss return loss def __init__(self): super().__init__() self.network = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2), # output: 64 x 16 x 16
        nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2), # output: 128 x 8 x 8
        nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(), nn.MaxPool2d(2, 2), # output: 256 x 4 x 4
        nn.Flatten(), nn.Linear(256*28*28, 1024), nn.ReLU(), nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 5)) def forward(self, xb): return self.network(xb) def validation_step(self, batch): images, labels = batch out = self(images) # Generate predictions loss = F.cross_entropy(out, labels) # Calculate loss acc = accuracy(out, labels) # Calculate accuracy return {'val_loss': loss.detach(), 'val_acc': acc} def validation_epoch_end(self, outputs): batch_losses = [x['val_loss'] for x in outputs] epoch_loss = torch.stack(batch_losses).mean() # Combine losses batch_accs = [x['val_acc'] for x in outputs] epoch_acc = torch.stack(batch_accs).mean() # Combine accuracies return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()} def epoch_end(self, epoch, result): print("Epoch [{}], train_loss: {:.4f}, val_loss: {:.4f}, val_acc: {:.4f}".format(????????????epoch,?result['train_loss'],?result['val_loss'],?result['val_ac']))

        訓(xùn)練模型


        首先訓(xùn)練模型,讓我們將超參數(shù)設(shè)置為:

        num_epochs = 10opt_func = torch.optim.Adamlr = 0.001


        現(xiàn)在,在將模型運(yùn)行10個epoach后,我們可以看到使用基本的卷積神經(jīng)網(wǎng)絡(luò)(CNN)模型達(dá)到了約65%。


        測試模型

        65%是一個很好的結(jié)果,因為我以前曾嘗試過使用帶有一些隱藏層的簡單神經(jīng)網(wǎng)絡(luò)(NN),結(jié)果僅為40%左右。因此,CNN非常適合對圖像進(jìn)行分類,因為它們有比其他形式的機(jī)器學(xué)習(xí)更好的檢測模式。

        使用轉(zhuǎn)移學(xué)習(xí)


        現(xiàn)在讓我們再次嘗試使用已經(jīng)定義的模型(如Resnet-18)進(jìn)行轉(zhuǎn)移學(xué)習(xí),以改善模型的預(yù)測。使用相同的超參數(shù)集,我們的測試集中可以達(dá)到82%左右,這是非常令人印象深刻的。如果我們使用其他更好的CNN架構(gòu),例如Resnet50,Inception V3等,則可以進(jìn)一步改善結(jié)果。


        plot_accuracies(history)

        保存模型

        訓(xùn)練完成后,我們必須保存我們的模型,以便我們可以使用它來根據(jù)模型生成預(yù)測,甚至將來可以進(jìn)行更多訓(xùn)練。

        weights_fname = 'flower-resnet.pth'torch.save(model.state_dict(), weights_fname)
        產(chǎn)生預(yù)測

        每個機(jī)器學(xué)習(xí)周期的目標(biāo)是創(chuàng)建一個可被用于對常規(guī)數(shù)據(jù)進(jìn)行分類的模型。這可以通過幾行python代碼為最終用戶實現(xiàn)模型。

        def predict_image(img, model):    # Convert to a batch of 1    xb = to_device(img.unsqueeze(0), device)    # Get predictions from model    yb = model(xb)    # Pick index with highest probability    _, preds  = torch.max(yb, dim=1)    # Retrieve the class label    return dataset.classes[preds[0].item()]
        img, label = test_ds[2]plt.imshow(img.permute(1, 2, 0))print('Label:', dataset.classes[label], ', Predicted:', predict_image(img, model))Label: sunflower , Predicted: sunflower

        我們還可以使用服務(wù)器上的模型來識別花朵的類型。該模型可以輕松部署在服務(wù)器上,以供最終用戶識別不同類型的花朵。


        ·? END? ·


        RECOMMEND

        推薦閱讀

        ?1.?深度學(xué)習(xí)——入門PyTorch(一)

        ?2.?深度學(xué)習(xí)——入門PyTorch(二)

        ?3. PyTorch入門——autograd(一)

        ?4.?PyTorch入門——autograd(二)

        ?5.?PyTorch入門——autograd(三)

        瀏覽 65
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點(diǎn)贊
        評論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            亚洲人妖操逼 | 美女免费网站 | 抽插特写豆花视频 | 少妇的诱惑播放 | 无遮无挡试看120秒动态图 | A片免费在线播放 | 五月天婷婷小说 | 爱爱爱爽爽爽 | aigao无码精品网站 | 中文字幕亚洲视频 |