1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        使用PyTorch進(jìn)行小樣本學(xué)習(xí)的圖像分類(lèi)

        共 7219字,需瀏覽 15分鐘

         ·

        2022-11-18 20:58


        來(lái)源:DeepHub IMBA

        本文約3600字,建議閱讀9分鐘
        本文手把手教你小樣本學(xué)習(xí)方法。

        近年來(lái),基于深度學(xué)習(xí)的模型在目標(biāo)檢測(cè)和圖像識(shí)別等任務(wù)中表現(xiàn)出色。像ImageNet這樣具有挑戰(zhàn)性的圖像分類(lèi)數(shù)據(jù)集,包含1000種不同的對(duì)象分類(lèi),現(xiàn)在一些模型已經(jīng)超過(guò)了人類(lèi)水平上。但是這些模型依賴于監(jiān)督訓(xùn)練流程,標(biāo)記訓(xùn)練數(shù)據(jù)的可用性對(duì)它們有重大影響,并且模型能夠檢測(cè)到的類(lèi)別也僅限于它們接受訓(xùn)練的類(lèi)。

        由于在訓(xùn)練過(guò)程中沒(méi)有足夠的標(biāo)記圖像用于所有類(lèi),這些模型在現(xiàn)實(shí)環(huán)境中可能不太有用。并且我們希望的模型能夠識(shí)別它在訓(xùn)練期間沒(méi)有見(jiàn)到過(guò)的類(lèi),因?yàn)閹缀醪豢赡茉谒袧撛趯?duì)象的圖像上進(jìn)行訓(xùn)練。我們將從幾個(gè)樣本中學(xué)習(xí)的問(wèn)題被稱為“少樣本學(xué)習(xí) Few-Shot learning”。

        什么是小樣本學(xué)習(xí)?



        少樣本學(xué)習(xí)是機(jī)器學(xué)習(xí)的一個(gè)子領(lǐng)域。它涉及到在只有少數(shù)訓(xùn)練樣本和監(jiān)督數(shù)據(jù)的情況下對(duì)新數(shù)據(jù)進(jìn)行分類(lèi)。只需少量的訓(xùn)練樣本,我們創(chuàng)建的模型就可以相當(dāng)好地執(zhí)行。

        考慮以下場(chǎng)景:在醫(yī)療領(lǐng)域,對(duì)于一些不常見(jiàn)的疾病,可能沒(méi)有足夠的x光圖像用于訓(xùn)練。對(duì)于這樣的場(chǎng)景,構(gòu)建一個(gè)小樣本學(xué)習(xí)分類(lèi)器是完美的解決方案。

        小樣本的變化

        一般來(lái)說(shuō),研究人員確定了四種類(lèi)型:

        1. N-Shot Learning (NSL)
        2. Few-Shot Learning ( FSL )
        3. One-Shot Learning (OSL)
        4. Zero-Shot Learning (ZSL)

        當(dāng)我們談?wù)?FSL 時(shí),我們通常指的是 N-way-K-Shot 分類(lèi)。N 代表類(lèi)別數(shù),K 代表每個(gè)類(lèi)中要訓(xùn)練的樣本數(shù)。所以N-Shot Learning 被視為比所有其他概念更廣泛的概念??梢哉f(shuō) Few-Shot、One-Shot 和 Zero-Shot是 NSL 的子領(lǐng)域。而零樣本學(xué)習(xí)旨在在沒(méi)有任何訓(xùn)練示例的情況下對(duì)看不見(jiàn)的類(lèi)進(jìn)行分類(lèi)。

        在 One-Shot Learning 中,每個(gè)類(lèi)只有一個(gè)樣本。Few-Shot 每個(gè)類(lèi)有 2 到 5 個(gè)樣本,也就是說(shuō) Few-Shot 是更靈活的 One-Shot Learning 版本。


        小樣本學(xué)習(xí)方法


        通常,在解決 Few Shot Learning 問(wèn)題時(shí)應(yīng)考慮兩種方法:

        數(shù)據(jù)級(jí)方法 (DLA)

        這個(gè)策略非常簡(jiǎn)單,如果沒(méi)有足夠的數(shù)據(jù)來(lái)創(chuàng)建實(shí)體模型并防止欠擬合和過(guò)擬合,那么就應(yīng)該添加更多數(shù)據(jù)。正因?yàn)槿绱?,許多 FSL 問(wèn)題都可以通過(guò)利用來(lái)更大大的基礎(chǔ)數(shù)據(jù)集的更多數(shù)據(jù)來(lái)解決?;緮?shù)據(jù)集的顯著特征是它缺少構(gòu)成我們對(duì) Few-Shot 挑戰(zhàn)的支持集的類(lèi)。例如,如果我們想要對(duì)某種鳥(niǎo)類(lèi)進(jìn)行分類(lèi),則基礎(chǔ)數(shù)據(jù)集可能包含許多其他鳥(niǎo)類(lèi)的圖片。

        參數(shù)級(jí)方法 (PLA)

        從參數(shù)級(jí)別的角度來(lái)看,F(xiàn)ew-Shot Learning 樣本相對(duì)容易過(guò)擬合,因?yàn)樗鼈兺ǔ>哂写蟮母呔S空間。限制參數(shù)空間、使用正則化和使用適當(dāng)?shù)膿p失函數(shù)將有助于解決這個(gè)問(wèn)題。少量的訓(xùn)練樣本將被模型泛化。

        通過(guò)將模型引導(dǎo)到廣闊的參數(shù)空間可以提高性能。由于缺乏訓(xùn)練數(shù)據(jù),正常的優(yōu)化方法可能無(wú)法產(chǎn)生準(zhǔn)確的結(jié)果。

        因?yàn)樯厦娴脑?,?xùn)練我們的模型以發(fā)現(xiàn)通過(guò)參數(shù)空間的最佳路徑,產(chǎn)生最佳的預(yù)測(cè)結(jié)果。這種方法被稱為元學(xué)習(xí)。

        小樣本學(xué)習(xí)圖像分類(lèi)算法


        有4種比較常見(jiàn)的小樣本學(xué)習(xí)的方法:

        與模型無(wú)關(guān)的元學(xué)習(xí) Model-Agnostic Meta-Learning

        基于梯度的元學(xué)習(xí) (GBML) 原則是 MAML 的基礎(chǔ)。在 GBML 中,元學(xué)習(xí)者通過(guò)基礎(chǔ)模型訓(xùn)練和學(xué)習(xí)所有任務(wù)表示的共享特征來(lái)獲得先前的經(jīng)驗(yàn)。每次有新任務(wù)要學(xué)習(xí)時(shí),元學(xué)習(xí)器都會(huì)利用其現(xiàn)有經(jīng)驗(yàn)和新任務(wù)提供的最少量的新訓(xùn)練數(shù)據(jù)進(jìn)行微調(diào)訓(xùn)練。

        一般情況下,如果我們隨機(jī)初始化參數(shù)經(jīng)過(guò)幾次更新算法將不會(huì)收斂到良好的性能。MAML 試圖解決這個(gè)問(wèn)題。MAML 只需幾個(gè)梯度步驟并且保證沒(méi)有過(guò)度擬合的前提下,為元參數(shù)學(xué)習(xí)器提供了可靠的初始化,這樣可以對(duì)新任務(wù)進(jìn)行最佳快速學(xué)習(xí)。

        步驟如下:

        1. 元學(xué)習(xí)者在每個(gè)分集(episode)開(kāi)始時(shí)創(chuàng)建自己的副本C,
        2. C 在這一分集上進(jìn)行訓(xùn)練(在 base-model 的幫助下),
        3. C 對(duì)查詢集進(jìn)行預(yù)測(cè),
        4. 從這些預(yù)測(cè)中計(jì)算出的損失用于更新 C,
        5. 這種情況一直持續(xù)到完成所有分集的訓(xùn)練。

        這種技術(shù)的最大優(yōu)勢(shì)在于,它被認(rèn)為與元學(xué)習(xí)算法的選擇無(wú)關(guān)。因此MAML 方法被廣泛用于許多需要快速適應(yīng)的機(jī)器學(xué)習(xí)算法,尤其是深度神經(jīng)網(wǎng)絡(luò)。

        匹配網(wǎng)絡(luò) Matching Networks

        為解決 FSL 問(wèn)題而創(chuàng)建的第一個(gè)度量學(xué)習(xí)方法是匹配網(wǎng)絡(luò) (MN)。

        當(dāng)使用匹配網(wǎng)絡(luò)方法解決 Few-Shot Learning 問(wèn)題時(shí)需要一個(gè)大的基礎(chǔ)數(shù)據(jù)集。

        將該數(shù)據(jù)集分為幾個(gè)分集之后,對(duì)于每一分集,匹配網(wǎng)絡(luò)進(jìn)行以下操作:

        • 來(lái)自支持集和查詢集的每個(gè)圖像都被饋送到一個(gè) CNN,該 CNN 為它們輸出特征的嵌入
        • 查詢圖像使用支持集訓(xùn)練的模型得到嵌入特征的余弦距離,通過(guò) softmax 進(jìn)行分類(lèi)
        • 分類(lèi)結(jié)果的交叉熵?fù)p失通過(guò) CNN 反向傳播更新特征嵌入模型

        匹配網(wǎng)絡(luò)可以通過(guò)這種方式學(xué)習(xí)構(gòu)建圖像嵌入。MN 能夠使用這種方法對(duì)照片進(jìn)行分類(lèi),并且無(wú)需任何特殊的類(lèi)別先驗(yàn)知識(shí)。他只要簡(jiǎn)單地比較類(lèi)的幾個(gè)實(shí)例就可以了。

        由于類(lèi)別因分集而異,因此匹配網(wǎng)絡(luò)會(huì)計(jì)算對(duì)類(lèi)別區(qū)分很重要的圖片屬性(特征)。而當(dāng)使用標(biāo)準(zhǔn)分類(lèi)時(shí),算法會(huì)選擇每個(gè)類(lèi)別獨(dú)有的特征。

        原型網(wǎng)絡(luò) Prototypical Networks

        與匹配網(wǎng)絡(luò)類(lèi)似的是原型網(wǎng)絡(luò)(PN)。它通過(guò)一些細(xì)微的變化來(lái)提高算法的性能。PN 比 MN 取得了更好的結(jié)果,但它們訓(xùn)練過(guò)程本質(zhì)上是相同的,只是比較了來(lái)自支持集的一些查詢圖片嵌入,但是原型網(wǎng)絡(luò)提供了不同的策略。

        我們需要在 PN 中創(chuàng)建類(lèi)的原型:通過(guò)對(duì)類(lèi)中圖像的嵌入進(jìn)行平均而創(chuàng)建的類(lèi)的嵌入。然后僅使用這些類(lèi)原型來(lái)比較查詢圖像嵌入。當(dāng)用于單樣本學(xué)習(xí)問(wèn)題時(shí),它可與匹配網(wǎng)絡(luò)相媲美。

        關(guān)系網(wǎng)絡(luò) Relation Network

        關(guān)系網(wǎng)絡(luò)可以說(shuō)繼承了所有上面提到方法的研究的結(jié)果。RN是基于PN思想的但包含了顯著的算法改進(jìn)。

        該方法使用的距離函數(shù)是可學(xué)習(xí)的,而不是像以前研究的事先定義它。 關(guān)系模塊位于嵌入模塊之上,嵌入模塊是從輸入圖像計(jì)算嵌入和類(lèi)原型的部分。

        可訓(xùn)練的關(guān)系模塊(距離函數(shù))輸入是查詢圖像的嵌入與每個(gè)類(lèi)的原型,輸出為每個(gè)分類(lèi)匹配的關(guān)系分?jǐn)?shù)。關(guān)系分?jǐn)?shù)通過(guò) Softmax 得到一個(gè)預(yù)測(cè)。


        使用 Open-AI Clip 進(jìn)行零樣本學(xué)習(xí)


        CLIP(Contrastive Language-Image Pre-Training)是一個(gè)在各種(圖像、文本)對(duì)上訓(xùn)練的神經(jīng)網(wǎng)絡(luò)。它無(wú)需直接針對(duì)任務(wù)進(jìn)行優(yōu)化,就可以為給定的圖像來(lái)預(yù)測(cè)最相關(guān)的文本片段(類(lèi)似于 GPT-2 和 3 的零樣本的功能)。

        CLIP 在 ImageNet“零樣本”上可以達(dá)到原始 ResNet50 的性能,而且需要不使用任何標(biāo)記示例,它克服了計(jì)算機(jī)視覺(jué)中的幾個(gè)主要挑戰(zhàn),下面我們使用Pytorch來(lái)實(shí)現(xiàn)一個(gè)簡(jiǎn)單的分類(lèi)模型。

        引入包

        ! pip install ftfy regex tqdm ! pip install git+https://github.com/openai/CLIP.gitimport numpy as np import torch from pkg_resources import packaging
        print("Torch version:", torch.__version__)


        加載模型

        import clipclip.available_models() # it will list the names of available CLIP modelsmodel, preprocess = clip.load("ViT-B/32") model.cuda().eval() input_resolution = model.visual.input_resolution context_length = model.context_length vocab_size = model.vocab_size
        print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}") print("Input resolution:", input_resolution) print("Context length:", context_length) print("Vocab size:", vocab_size)


        圖像預(yù)處理

        我們將向模型輸入8個(gè)示例圖像及其文本描述,并比較對(duì)應(yīng)特征之間的相似性。

        分詞器不區(qū)分大小寫(xiě),我們可以自由地給出任何合適的文本描述。

        import os import skimage import IPython.display import matplotlib.pyplot as plt from PIL import Image import numpy as np
        from collections import OrderedDict import torch
        %matplotlib inline %config InlineBackend.figure_format = 'retina'
        # images in skimage to use and their textual descriptions descriptions = { "page": "a page of text about segmentation", "chelsea": "a facial photo of a tabby cat", "astronaut": "a portrait of an astronaut with the American flag", "rocket": "a rocket standing on a launchpad", "motorcycle_right": "a red motorcycle standing in a garage", "camera": "a person looking at a camera on a tripod", "horse": "a black-and-white silhouette of a horse", "coffee": "a cup of coffee on a saucer" }original_images = [] images = [] texts = [] plt.figure(figsize=(16, 5))
        for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]: name = os.path.splitext(filename)[0] if name not in descriptions: continue
        image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
        plt.subplot(2, 4, len(images) + 1) plt.imshow(image) plt.title(f"{filename}\n{descriptions[name]}") plt.xticks([]) plt.yticks([])
        original_images.append(image) images.append(preprocess(image)) texts.append(descriptions[name])
        plt.tight_layout()


        結(jié)果的可視化如下:


        我們對(duì)圖像進(jìn)行規(guī)范化,對(duì)每個(gè)文本輸入進(jìn)行標(biāo)記,并運(yùn)行模型的正傳播獲得圖像和文本的特征。

        image_input = torch.tensor(np.stack(images)).cuda() text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
        with torch.no_grad(): image_features = model.encode_image(image_input).float() text_features = model.encode_text(text_tokens).float()

        我們將特征歸一化,并計(jì)算每一對(duì)的點(diǎn)積,進(jìn)行余弦相似度計(jì)算。

        image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
        count = len(descriptions)
        plt.figure(figsize=(20, 14)) plt.imshow(similarity, vmin=0.1, vmax=0.3) # plt.colorbar() plt.yticks(range(count), texts, fontsize=18) plt.xticks([]) for i, image in enumerate(original_images): plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower") for x in range(similarity.shape[1]): for y in range(similarity.shape[0]): plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
        for side in ["left", "top", "right", "bottom"]: plt.gca().spines[side].set_visible(False)
        plt.xlim([-0.5, count - 0.5]) plt.ylim([count + 0.5, -2])
        plt.title("Cosine similarity between text and image features", size=20)

        零樣本的圖像分類(lèi):

        from torchvision.datasets import CIFAR100 cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True) text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes] text_tokens = clip.tokenize(text_descriptions).cuda() with torch.no_grad(): text_features = model.encode_text(text_tokens).float() text_features /= text_features.norm(dim=-1, keepdim=True)
        text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) top_probs, top_labels = text_probs.cpu().topk(5, dim=-1) plt.figure(figsize=(16, 16)) for i, image in enumerate(original_images): plt.subplot(4, 4, 2 * i + 1) plt.imshow(image) plt.axis("off")
        plt.subplot(4, 4, 2 * i + 2) y = np.arange(top_probs.shape[-1]) plt.grid() plt.barh(y, top_probs[i]) plt.gca().invert_yaxis() plt.gca().set_axisbelow(True) plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()]) plt.xlabel("probability")
        plt.subplots_adjust(wspace=0.5) plt.show()

        可以看到,分類(lèi)的效果還是非常好的。

        編輯:黃繼彥

        瀏覽 47
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            亚洲欧美成人网 | 国产精品久久无码一区二区三区网 | 国产精品自拍偷怕 | 久久88 | 国产又大又长又 | 人妖精品人妖TS视频在线观看 | 18禁国产 | 丰满少妇18p | 成人在线观看国产 | 西方美术人文艺术 |