使用Transformer來(lái)做物體檢測(cè)
點(diǎn)擊左上方藍(lán)字關(guān)注我們

轉(zhuǎn)載自 | 學(xué)算法的小黑狗


介紹
DEtection TRansformer (DETR)是Facebook研究團(tuán)隊(duì)巧妙地利用了Transformer 架構(gòu)開(kāi)發(fā)的一個(gè)目標(biāo)檢測(cè)模型。在這篇文章中,我將通過(guò)分析DETR架構(gòu)的內(nèi)部工作方式來(lái)幫助提供一些關(guān)于它的直覺(jué)。
下面,我將解釋一些結(jié)構(gòu),但是如果你只是想了解如何使用模型,可以直接跳到代碼部分。

結(jié)構(gòu)
DETR模型由一個(gè)預(yù)訓(xùn)練的CNN骨干(如ResNet)組成,它產(chǎn)生一組低維特征集。這些特征被格式化為一個(gè)特征集合并添加位置編碼,輸入一個(gè)由Transformer組成的編碼器和解碼器中,和原始的Transformer論文中描述的Encoder-Decoder的使用方式非常的類(lèi)似。解碼器的輸出然后被送入固定數(shù)量的預(yù)測(cè)頭,這些預(yù)測(cè)頭由預(yù)定義數(shù)量的前饋網(wǎng)絡(luò)組成。每個(gè)預(yù)測(cè)頭的輸出都包含一個(gè)類(lèi)預(yù)測(cè)和一個(gè)預(yù)測(cè)框。損失是通過(guò)計(jì)算二分匹配損失來(lái)計(jì)算的。

該模型做出了預(yù)定義數(shù)量的預(yù)測(cè),并且每個(gè)預(yù)測(cè)都是并行計(jì)算的。

CNN主干
假設(shè)我們的輸入圖像,有三個(gè)輸入通道。CNN backbone由一個(gè)(預(yù)訓(xùn)練過(guò)的)CNN(通常是ResNet)組成,我們用它來(lái)生成C個(gè)具有寬度W和高度H的低維特征(在實(shí)踐中,我們?cè)O(shè)置C=2048, W=W?/32和H=H?/32)。
這留給我們的是C個(gè)二維特征,由于我們將把這些特征傳遞給一個(gè)transformer,每個(gè)特征必須允許編碼器將每個(gè)特征處理為一個(gè)序列的方式重新格式化。這是通過(guò)將特征矩陣扁平化為H?W向量,然后將每個(gè)向量連接起來(lái)來(lái)實(shí)現(xiàn)的。

扁平化的卷積特征再加上空間位置編碼,位置編碼既可以學(xué)習(xí),也可以預(yù)定義。

The Transformer
Transformer幾乎與原始的編碼器-解碼器架構(gòu)完全相同。不同之處在于,每個(gè)解碼器層并行解碼N個(gè)(預(yù)定義的數(shù)目)目標(biāo)。該模型還學(xué)習(xí)了一組N個(gè)目標(biāo)的查詢(xún),這些查詢(xún)是(類(lèi)似于編碼器)學(xué)習(xí)出來(lái)的位置編碼。


目標(biāo)查詢(xún)
下圖描述了N=20個(gè)學(xué)習(xí)出來(lái)的目標(biāo)查詢(xún)(稱(chēng)為prediction slots)如何聚焦于一張圖像的不同區(qū)域。

“我們觀(guān)察到,在不同的操作模式下,每個(gè)slot 都會(huì)學(xué)習(xí)特定的區(qū)域和框大小?!?—— DETR的作者
理解目標(biāo)查詢(xún)的直觀(guān)方法是想象每個(gè)目標(biāo)查詢(xún)都是一個(gè)人。每個(gè)人都可以通過(guò)注意力來(lái)查看圖像的某個(gè)區(qū)域。一個(gè)目標(biāo)查詢(xún)總是會(huì)問(wèn)圖像中心是什么,另一個(gè)總是會(huì)問(wèn)左下角是什么,以此類(lèi)推。

使用PyTorch實(shí)現(xiàn)簡(jiǎn)單的DETR
import torch
import torch.nn as nn
from torchvision.models import resnet50
class SimpleDETR(nn.Module):
"""
Minimal Example of the Detection Transformer model with learned positional embedding
"""
def __init__(self, num_classes, hidden_dim, num_heads,
num_enc_layers, num_dec_layers):
super(SimpleDETR, self).__init__()
self.num_classes = num_classes
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.num_enc_layers = num_enc_layers
self.num_dec_layers = num_dec_layers
# CNN Backbone
self.backbone = nn.Sequential(
*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
# Transformer
self.transformer = nn.Transformer(hidden_dim, num_heads,
num_enc_layers, num_dec_layers)
# Prediction Heads
self.to_classes = nn.Linear(hidden_dim, num_classes+1)
self.to_bbox = nn.Linear(hidden_dim, 4)
# Positional Encodings
self.object_query = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2)
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, X):
X = self.backbone(X)
h = self.conv(X)
H, W = h.shape[-2:]
pos_enc = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H,1,1),
self.row_embed[:H].unsqueeze(1).repeat(1,W,1)],
dim=-1).flatten(0,1).unsqueeze(1)
h = self.transformer(pos_enc + h.flatten(2).permute(2,0,1),
self.object_query.unsqueeze(1))
class_pred = self.to_classes(h)
bbox_pred = self.to_bbox(h).sigmoid()
return class_pred, bbox_pred

二分匹配損失 (Optional)


框損失的計(jì)算為預(yù)測(cè)值與ground truth的L?損失和的GIOU損失的線(xiàn)性組合。同樣,如果你想象兩個(gè)不相交的框,那么框的錯(cuò)誤將不會(huì)提供任何有意義的上下文(我們可以從下面的框損失的定義中看到)。

可以把上面的等式看作是與預(yù)測(cè)相關(guān)聯(lián)的總損失,其中面積誤差的重要性是λ???,距離誤差的重要性是。

由于我們從已知的已知類(lèi)的數(shù)目來(lái)預(yù)測(cè)類(lèi),那么類(lèi)預(yù)測(cè)就是一個(gè)分類(lèi)問(wèn)題,因此我們可以使用交叉熵?fù)p失來(lái)計(jì)算類(lèi)預(yù)測(cè)誤差。我們將損失函數(shù)定義為每N個(gè)預(yù)測(cè)損失的總和:


為目標(biāo)檢測(cè)使用DETR
在這里,你可以學(xué)習(xí)如何加載預(yù)訓(xùn)練的DETR模型,以便使用PyTorch進(jìn)行目標(biāo)檢測(cè)。
8.1 加載模型
首先導(dǎo)入需要的模塊。
# Import required modules
import torch
from torchvision import transforms as T import requests # for loading images from web
from PIL import Image # for viewing images
import matplotlib.pyplot as plt
下面的代碼用ResNet50作為CNN骨干從torch hub加載預(yù)訓(xùn)練的模型。其他主干請(qǐng)參見(jiàn)DETR github:https://github.com/facebookresearch/detr
detr = torch.hub.load('facebookresearch/detr',
'detr_resnet50',
pretrained=True)
8.2 加載一張圖像
要從web加載圖像,我們使用requests庫(kù):
url = 'https://www.tempetourism.com/wp-content/uploads/Postino-Downtown-Tempe-2.jpg' # Sample imageimage = Image.open(requests.get(url, stream=True).raw) plt.imshow(image)
plt.show()

8.3 設(shè)置目標(biāo)檢測(cè)的Pipeline
為了將圖像輸入到模型中,我們需要將PIL圖像轉(zhuǎn)換為張量,這是通過(guò)使用torchvision的transforms庫(kù)來(lái)完成的。
transform = T.Compose([T.Resize(800),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
上面的變換調(diào)整了圖像的大小,將PIL圖像進(jìn)行轉(zhuǎn)換,并用均值-標(biāo)準(zhǔn)差對(duì)圖像進(jìn)行歸一化。其中[0.485,0.456,0.406]為各顏色通道的均值,[0.229,0.224,0.225]為各顏色通道的標(biāo)準(zhǔn)差。
我們裝載的模型是預(yù)先在COCO Dataset上訓(xùn)練的,有91個(gè)類(lèi),還有一個(gè)表示空類(lèi)(沒(méi)有目標(biāo))的附加類(lèi)。我們用下面的代碼手動(dòng)定義每個(gè)標(biāo)簽:
CLASSES =
['N/A', 'Person', 'Bicycle', 'Car', 'Motorcycle', 'Airplane', 'Bus', 'Train', 'Truck', 'Boat', 'Traffic-Light', 'Fire-Hydrant', 'N/A', 'Stop-Sign', 'Parking Meter', 'Bench', 'Bird', 'Cat', 'Dog', 'Horse', 'Sheep', 'Cow', 'Elephant', 'Bear', 'Zebra', 'Giraffe', 'N/A', 'Backpack', 'Umbrella', 'N/A', 'N/A', 'Handbag', 'Tie', 'Suitcase', 'Frisbee', 'Skis', 'Snowboard', 'Sports-Ball', 'Kite', 'Baseball Bat', 'Baseball Glove', 'Skateboard', 'Surfboard', 'Tennis Racket', 'Bottle', 'N/A', 'Wine Glass', 'Cup', 'Fork', 'Knife', 'Spoon', 'Bowl', 'Banana', 'Apple', 'Sandwich', 'Orange', 'Broccoli', 'Carrot', 'Hot-Dog', 'Pizza', 'Donut', 'Cake', 'Chair', 'Couch', 'Potted Plant', 'Bed', 'N/A', 'Dining Table', 'N/A','N/A', 'Toilet', 'N/A', 'TV', 'Laptop', 'Mouse', 'Remote', 'Keyboard', 'Cell-Phone', 'Microwave', 'Oven', 'Toaster', 'Sink', 'Refrigerator', 'N/A', 'Book', 'Clock', 'Vase', 'Scissors', 'Teddy-Bear', 'Hair-Dryer', 'Toothbrush']
如果我們想輸出不同顏色的邊框,我們可以手動(dòng)定義我們想要的RGB格式的顏色
COLORS = [
[0.000, 0.447, 0.741],
[0.850, 0.325, 0.098],
[0.929, 0.694, 0.125],
[0.494, 0.184, 0.556],
[0.466, 0.674, 0.188],
[0.301, 0.745, 0.933]
]
8.4 格式化輸出
我們還需要重新格式化模型的輸出。給定一個(gè)轉(zhuǎn)換后的圖像,模型將輸出一個(gè)字典,包含100個(gè)預(yù)測(cè)類(lèi)的概率和100個(gè)預(yù)測(cè)邊框。
每個(gè)包圍框的形式為(x, y, w, h),其中(x,y)為包圍框的中心(包圍框是單位正方形[0,1]×[0,1]), w, h為包圍框的寬度和高度。因此,我們需要將邊界框輸出轉(zhuǎn)換為初始和最終坐標(biāo),并重新縮放框以適應(yīng)圖像的實(shí)際大小。
下面的函數(shù)返回邊界框端點(diǎn):
# Get coordinates (x0, y0, x1, y0) from model output (x, y, w, h)def get_box_coords(boxes):
x, y, w, h = boxes.unbind(1)
x0, y0 = (x - 0.5 * w), (y - 0.5 * h)
x1, y1 = (x + 0.5 * w), (y + 0.5 * h)
box = [x0, y0, x1, y1]
return torch.stack(box, dim=1)
我們還需要縮放了框的大小。下面的函數(shù)為我們做了這些:
# Scale box from [0,1]x[0,1] to [0, width]x[0, height]def scale_boxes(output_box, width, height):
box_coords = get_box_coords(output_box)
scale_tensor = torch.Tensor(
[width, height, width, height]).to(
torch.cuda.current_device()) return box_coords * scale_tensor
現(xiàn)在我們需要一個(gè)函數(shù)來(lái)封裝我們的目標(biāo)檢測(cè)pipeline。下面的detect函數(shù)為我們完成了這項(xiàng)工作。
# Object Detection Pipelinedef detect(im, model, transform):
device = torch.cuda.current_device()
width = im.size[0]
height = im.size[1]
# mean-std normalize the input image (batch-size: 1)
img = transform(im).unsqueeze(0)
img = img.to(device)
# demo model only support by default images with aspect ratio between 0.5 and 2 assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, # propagate through the model
outputs = model(img) # keep only predictions with 0.7+ confidence
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.85
# convert boxes from [0; 1] to image scales
bboxes_scaled = scale_boxes(outputs['pred_boxes'][0, keep], width, height) return probas[keep], bboxes_scaled
現(xiàn)在,我們需要做的是運(yùn)行以下程序來(lái)獲得我們想要的輸出:
probs, bboxes = detect(image, detr, transform)
8.5 繪制結(jié)果
現(xiàn)在我們有了檢測(cè)到的目標(biāo),我們可以使用一個(gè)簡(jiǎn)單的函數(shù)來(lái)可視化它們。
# Plot Predicted Bounding Boxesdef plot_results(pil_img, prob, boxes,labels=True):
plt.figure(figsize=(16,10))
plt.imshow(pil_img)
ax = plt.gca()
for prob, (x0, y0, x1, y1), color in zip(prob, boxes.tolist(), COLORS * 100): ax.add_patch(plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
fill=False, color=color, linewidth=2))
cl = prob.argmax()
text = f'{CLASSES[cl]}: {prob[cl]:0.2f}'
if labels:
ax.text(x0, y0, text, fontsize=15,
bbox=dict(facecolor=color, alpha=0.75))
plt.axis('off')
plt.show()
現(xiàn)在可以可視化結(jié)果:
plot_results(image, probs, bboxes, labels=True)


英文原文
https://medium.com/swlh/object-detection-with-transformers-437217a3d62e
END
整理不易,點(diǎn)贊三連↓
