Complete Text Detection and Recognition | With Source Code
2024-07-16 10:12
Have you noticed that every shop writes its name in its own distinctive way? Well-known brands such as Gucci, Sears, Pantaloons, and Lifestyle use curved or circular lettering in their logos. All of this is great for attracting customers, but it poses a real challenge for a deep learning (DL) model that has to detect and recognize the text.
What do you do when you read the text on a banner? Your eyes first detect the presence of text, locate each character, and then recognize those characters. That is exactly what a DL model has to do. OCR has recently become a hot topic in deep learning, with each new architecture striving to outperform the others.
Tesseract, a popular deep-learning-based OCR module, performs well on structured text such as documents, but struggles with curved, irregularly shaped text in fancy fonts. Fortunately, Clova AI has released excellent networks that outperform Tesseract on the wide variety of text appearances found in the real world. In this post we briefly discuss these architectures and learn how to combine them into a single pipeline.
Text Detection with CRAFT
Scene text detection is the task of detecting text regions in a complex background and marking them with bounding boxes. CRAFT, short for Character-Region Awareness For Text detection, was proposed in 2019; its main goal is to localize individual character regions and link the detected characters into text instances.
CRAFT uses a fully convolutional network architecture based on VGG-16. Put simply, VGG-16 is the feature-extraction backbone that encodes the network's input into a feature representation. The decoding part of the CRAFT network is similar to U-Net: it has skip connections that aggregate low-level features. CRAFT predicts two scores for each character:
Region score: as the name suggests, it gives the region of a character; it localizes the character.
Affinity score: 'affinity' refers to the degree to which one substance tends to combine with another.
The affinity score therefore merges characters into a single instance (a word). CRAFT produces two maps as output: a region map and an affinity map. Let's understand what they mean with an example:
Input image
The regions where characters are present are marked in the region map:
Region map
The affinity map graphically represents related characters. Red indicates characters with high affinity that must be merged into one word:
Affinity map
Finally, the affinity and region scores are combined to give a bounding box for every word. The order of the coordinates is: (top-left), (top-right), (bottom-right), (bottom-left), where each coordinate is an (x, y) pair.
Why not the usual four-value box format?
Look at the image below: could you localize "LOVE" with only 4 values?
CRAFT is multilingual, which means it can detect text in any script.
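To make the box format above concrete, here is a minimal sketch (not from the original post; the corner values are invented) that draws one CRAFT-style word box with OpenCV. A rotated word like this cannot be tightly enclosed by a single axis-aligned (x_min, y_min, x_max, y_max) rectangle, which is why all four corner points are kept:

import numpy as np
import cv2

# One word box in CRAFT's order: top-left, top-right, bottom-right, bottom-left,
# each corner an (x, y) pair (values here are purely illustrative).
box = np.array([[120, 80], [310, 40], [325, 110], [135, 150]], dtype=np.int32)

canvas = np.full((300, 400, 3), 255, dtype=np.uint8)  # white canvas
cv2.polylines(canvas, [box.reshape(-1, 1, 2)], isClosed=True, color=(0, 0, 255), thickness=2)
cv2.imwrite('word_box.jpg', canvas)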
Text Recognition: the Four-Stage Scene Text Recognition Framework
In 2019, Clova AI published a research paper analyzing the inconsistencies in existing scene text recognition (STR) datasets and proposing a unified four-stage framework into which most existing STR models fit.
Let's discuss the four stages:
Transformation: remember that we are dealing with scene text, which can be arbitrarily shaped and curved. If we went straight to feature extraction, the extractor would also have to learn the geometry of the input text, which is extra work for that module. The STR network therefore applies a Thin-Plate Spline (TPS) transformation and normalizes the input text into a rectangular shape.
Feature extraction: maps the transformed image onto a set of features relevant for character recognition; font, color, size, and background are discarded. The authors experimented with different backbones, including ResNet, VGG, and RCNN.
Sequence modeling: if I write 'ba_', you would most likely guess that the missing letter is 'd', 'g', or 't' rather than 'u' or 'p'. How do we teach the network to capture such contextual information? With BiLSTMs! However, BiLSTMs consume memory, so this stage can be enabled or disabled as needed.
Prediction: this stage estimates the output character sequence from the features identified in the image.
The authors ran several experiments, choosing a different network for each stage. The accuracies are summarized in the table below:
Code
CRAFT predicts a bounding box for every word. The four-stage STR network takes a single word (as an image) as input and predicts its letters. If you are working with single-word images (like CUTE80), OCR with these DL modules will be a breeze.
Step 1: Install the requirements
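The exact install commands appeared as an image in the original post. As a rough Colab-style sketch, something along these lines should work; the package list is an assumption based on what the two repositories import, so adjust it to your environment:

!pip install torch torchvision opencv-python scikit-image scipy pandas lmdb natsort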
Step 2: Clone the repositories
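The clone commands were also shown as an image. The two public Clova AI repositories used in this post can be fetched like this (the working-directory layout is up to you):

!git clone https://github.com/clovaai/CRAFT-pytorch.git
!git clone https://github.com/clovaai/deep-text-recognition-benchmark.git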
Step 3: Modify craft_utils.py to return the scores of the detection boxes
CRAFT returns the bounding boxes whose score is above a certain threshold. If you want to see the score of each bounding box, we need to make a few changes to the original repository. Open craft_utils.py in the cloned CRAFT repository; you need to change lines 83 and 239 so that the file looks like the code below.
"""Modify to Return Scores of Detection Boxes""""""Copyright (c) 2019-present NAVER Corp.MIT License"""# -*- coding: utf-8 -*-import numpy as npimport cv2import math""" auxilary functions """# unwarp corodinatesdef warpCoord(Minv, pt):out = np.matmul(Minv, (pt[0], pt[1], 1))return np.array([out[0]/out[2], out[1]/out[2]])""" end of auxilary functions """def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):# prepare datalinkmap = linkmap.copy()textmap = textmap.copy()img_h, img_w = textmap.shape""" labeling method """ret, text_score = cv2.threshold(textmap, low_text, 1, 0)ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)text_score_comb = np.clip(text_score + link_score, 0, 1)nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)det = []det_scores = []mapper = []for k in range(1,nLabels):# size filteringsize = stats[k, cv2.CC_STAT_AREA]if size < 10: continue# thresholdingif np.max(textmap[labels==k]) < text_threshold: continue# make segmentation mapsegmap = np.zeros(textmap.shape, dtype=np.uint8)segmap[labels==k] = 255segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link areax, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1# boundary checkif sx < 0 : sx = 0if sy < 0 : sy = 0if ex >= img_w: ex = img_wif ey >= img_h: ey = img_hkernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)# make boxnp_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)rectangle = cv2.minAreaRect(np_contours)box = cv2.boxPoints(rectangle)# align diamond-shapew, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])box_ratio = max(w, h) / (min(w, h) + 1e-5)if abs(1 - box_ratio) <= 0.1:l, r = min(np_contours[:,0]), max(np_contours[:,0])t, b = min(np_contours[:,1]), max(np_contours[:,1])box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)# make clock-wise orderstartidx = box.sum(axis=1).argmin()box = np.roll(box, 4-startidx, 0)box = np.array(box)det.append(box)mapper.append(k)det_scores.append(np.max(textmap[labels==k]))return det, labels, mapper, det_scoresdef getPoly_core(boxes, labels, mapper, linkmap):# configsnum_cp = 5max_len_ratio = 0.7expand_ratio = 1.45max_r = 2.0step_r = 0.2polys = []for k, box in enumerate(boxes):# size filter for small instancew, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)if w < 10 or h < 10:polys.append(None); continue# warp imagetar = np.float32([[0,0],[w,0],[w,h],[0,h]])M = cv2.getPerspectiveTransform(box, tar)word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)try:Minv = np.linalg.inv(M)except:polys.append(None); continue# binarization for selected labelcur_label = mapper[k]word_label[word_label != cur_label] = 0word_label[word_label > 0] = 1""" Polygon generation """# find top/bottom contourscp = []max_len = -1for i in range(w):region = np.where(word_label[:,i] != 0)[0]if len(region) < 2 : continuecp.append((i, region[0], region[-1]))length = region[-1] - region[0] + 1if length > max_len: max_len = length# pass if max_len is similar to hif h * max_len_ratio < max_len:polys.append(None); continue# get pivot points with fixed lengthtot_seg = num_cp * 2 + 1seg_w = w / 
tot_seg # segment widthpp = [None] * num_cp # init pivot pointscp_section = [[0, 0]] * tot_segseg_height = [0] * num_cpseg_num = 0num_sec = 0prev_h = -1for i in range(0,len(cp)):(x, sy, ey) = cp[i]if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:# average previous segmentif num_sec == 0: breakcp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]num_sec = 0# reset variablesseg_num += 1prev_h = -1# accumulate center pointscy = (sy + ey) * 0.5cur_h = ey - sy + 1cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]num_sec += 1if seg_num % 2 == 0: continue # No polygon areaif prev_h < cur_h:pp[int((seg_num - 1)/2)] = (x, cy)seg_height[int((seg_num - 1)/2)] = cur_hprev_h = cur_h# processing last segmentif num_sec != 0:cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]# pass if num of pivots is not sufficient or segment widh is smaller than character heightif None in pp or seg_w < np.max(seg_height) * 0.25:polys.append(None); continue# calc median maximum of pivot pointshalf_char_h = np.median(seg_height) * expand_ratio / 2# calc gradiant and apply to make horizontal pivotsnew_pp = []for i, (x, cy) in enumerate(pp):dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]if dx == 0: # gradient if zeronew_pp.append([x, cy - half_char_h, x, cy + half_char_h])continuerad = - math.atan2(dy, dx)c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)new_pp.append([x - s, cy - c, x + s, cy + c])# get edge points to cover character heatmapsisSppFound, isEppFound = False, Falsegrad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])for r in np.arange(0.5, max_r, step_r):dx = 2 * half_char_h * rif not isSppFound:line_img = np.zeros(word_label.shape, dtype=np.uint8)dy = grad_s * dxp = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:spp = pisSppFound = Trueif not isEppFound:line_img = np.zeros(word_label.shape, dtype=np.uint8)dy = grad_e * dxp = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:epp = pisEppFound = Trueif isSppFound and isEppFound:break# pass if boundary of polygon is not foundif not (isSppFound and isEppFound):polys.append(None); continue# make final polygonpoly = []poly.append(warpCoord(Minv, (spp[0], spp[1])))for p in new_pp:poly.append(warpCoord(Minv, (p[0], p[1])))poly.append(warpCoord(Minv, (epp[0], epp[1])))poly.append(warpCoord(Minv, (epp[2], epp[3])))for p in reversed(new_pp):poly.append(warpCoord(Minv, (p[2], p[3])))poly.append(warpCoord(Minv, (spp[2], spp[3])))# add to final resultpolys.append(np.array(poly))return polysdef getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):boxes, labels, mapper, det_scores = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)if poly:polys = getPoly_core(boxes, labels, mapper, linkmap)else:polys = [None] * len(boxes)return boxes, polys, det_scoresdef adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):if len(polys) > 0:polys = np.array(polys)for k in 
range(len(polys)):if polys[k] is not None:polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)return polys
Step 4: Remove the argument parser from CRAFT's test.py
Open test.py and modify it as shown below; the argument parser has been removed.
"""Modify to Remove Argument Parser""""""Copyright (c) 2019-present NAVER Corp.MIT License"""# -*- coding: utf-8 -*-import sysimport osimport timeimport argparseimport torchimport torch.nn as nnimport torch.backends.cudnn as cudnnfrom torch.autograd import Variablefrom PIL import Imageimport cv2from skimage import ioimport numpy as npimport craft_utilsimport imgprocimport file_utilsimport jsonimport zipfilefrom craft import CRAFTfrom collections import OrderedDictdef copyStateDict(state_dict):if list(state_dict.keys())[0].startswith("module"):start_idx = 1else:start_idx = 0new_state_dict = OrderedDict()for k, v in state_dict.items():name = ".".join(k.split(".")[start_idx:])new_state_dict[name] = vreturn new_state_dictdef test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, args, refine_net=None):t0 = time.time()# resizeimg_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)ratio_h = ratio_w = 1 / target_ratio# preprocessingx = imgproc.normalizeMeanVariance(img_resized)x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]if cuda:x = x.cuda()# forward passwith torch.no_grad():y, feature = net(x)# make score and link mapscore_text = y[0,:,:,0].cpu().data.numpy()score_link = y[0,:,:,1].cpu().data.numpy()# refine linkif refine_net is not None:with torch.no_grad():y_refiner = refine_net(y, feature)score_link = y_refiner[0,:,:,0].cpu().data.numpy()t0 = time.time() - t0t1 = time.time()# Post-processingboxes, polys, det_scores = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)# coordinate adjustmentboxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)for k in range(len(polys)):if polys[k] is None: polys[k] = boxes[k]t1 = time.time() - t1# render results (optional)render_img = score_text.copy()render_img = np.hstack((render_img, score_link))ret_score_text = imgproc.cvt2HeatmapImg(render_img)if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))return boxes, polys, ret_score_text, det_scores
Step 5: Write a separate script that saves the image names and detection-box coordinates to a CSV file
This will help us crop the words that are fed to the four-stage STR network, and it keeps all the bounding-box and text information in one place. Create a new file (I named it pipeline.py) and add the following code.
import sys
import os
import time
import argparse

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

from PIL import Image

import cv2
from skimage import io
import numpy as np
import craft_utils
import test
import imgproc
import file_utils
import json
import zipfile
import pandas as pd

from craft import CRAFT

from collections import OrderedDict
from google.colab.patches import cv2_imshow


def str2bool(v):
    return v.lower() in ("yes", "y", "true", "t", "1")

# CRAFT
parser = argparse.ArgumentParser(description='CRAFT Text Detection')
parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')
parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')
parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')
parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')
parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')
parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')

args = parser.parse_args()

""" For test images in a folder """
image_list, _, _ = file_utils.get_files(args.test_folder)

image_names = []
image_paths = []

# CUSTOMISE START
start = args.test_folder

for num in range(len(image_list)):
    image_names.append(os.path.relpath(image_list[num], start))

result_folder = './Results'
if not os.path.isdir(result_folder):
    os.mkdir(result_folder)

if __name__ == '__main__':

    data = pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])
    data['image_name'] = image_names

    # load net
    net = CRAFT()     # initialize

    print('Loading weights from checkpoint (' + args.trained_model + ')')
    if args.cuda:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))
    else:
        net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net.eval()

    # LinkRefiner
    refine_net = None
    if args.refine:
        from refinenet import RefineNet
        refine_net = RefineNet()
        print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')
        if args.cuda:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))
            refine_net = refine_net.cuda()
            refine_net = torch.nn.DataParallel(refine_net)
        else:
            refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))

        refine_net.eval()
        args.poly = True

    t = time.time()

    # load data
    for k, image_path in enumerate(image_list):
        print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
        image = imgproc.loadImage(image_path)

        bboxes, polys, score_text, det_scores = test.test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, args, refine_net)

        bbox_score = {}

        for box_num in range(len(bboxes)):
            key = str(det_scores[box_num])
            item = bboxes[box_num]
            bbox_score[key] = item

        data['word_bboxes'][k] = bbox_score

        # save score text
        filename, file_ext = os.path.splitext(os.path.basename(image_path))
        mask_file = result_folder + "/res_" + filename + '_mask.jpg'
        cv2.imwrite(mask_file, score_text)

        file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)

    data.to_csv('/content/Pipeline/data.csv', sep=',', na_rep='Unknown')

    print("elapsed time : {}s".format(time.time() - t))
The pandas DataFrame (the variable data) stores the image names and the bounding boxes of the words they contain in separate columns. We keep only the image name rather than its full path to avoid clutter; you can of course customize this to your needs. Now run the script:
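The exact invocation was shown as an image in the original post. Given the argument parser defined in pipeline.py above, it would look roughly like this in Colab (the weight and image folder paths are assumptions; point them at your own files):

!python pipeline.py --trained_model weights/craft_mlt_25k.pth --test_folder /content/Pipeline/test_images/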
At this stage the CSV looks like this: for every detection we store a Python dictionary of score: coordinates entries.
Step 6: Crop the words
We now have the coordinates and the score of every box, so we can set a threshold and crop the words whose characters we want to recognize. Create a new script, crop_images.py, and remember to fill in your paths where indicated. The cropped words are saved in the 'dir' folder: we create one folder per image and save the words cropped from it in the format <parent image>_<the 8 coordinates separated by underscores>, which helps you keep track of which image each cropped word came from.
import os
import numpy as np
import cv2
import pandas as pd
from google.colab.patches import cv2_imshow


def crop(pts, image):
    """
    Takes inputs as 8 points
    and Returns cropped, masked image with a white background
    """
    rect = cv2.boundingRect(pts)
    x, y, w, h = rect
    cropped = image[y:y+h, x:x+w].copy()
    pts = pts - pts.min(axis=0)
    mask = np.zeros(cropped.shape[:2], np.uint8)
    cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)
    dst = cv2.bitwise_and(cropped, cropped, mask=mask)
    bg = np.ones_like(cropped, np.uint8) * 255
    cv2.bitwise_not(bg, bg, mask=mask)
    dst2 = bg + dst

    return dst2


def generate_words(image_name, score_bbox, image):
    num_bboxes = len(score_bbox)
    for num in range(num_bboxes):
        bbox_coords = score_bbox[num].split(':')[-1].split(',\n')
        if bbox_coords != ['{}']:
            l_t = float(bbox_coords[0].strip(' array([').strip(']').split(',')[0])
            t_l = float(bbox_coords[0].strip(' array([').strip(']').split(',')[1])
            r_t = float(bbox_coords[1].strip(' [').strip(']').split(',')[0])
            t_r = float(bbox_coords[1].strip(' [').strip(']').split(',')[1])
            r_b = float(bbox_coords[2].strip(' [').strip(']').split(',')[0])
            b_r = float(bbox_coords[2].strip(' [').strip(']').split(',')[1])
            l_b = float(bbox_coords[3].strip(' [').strip(']').split(',')[0])
            b_l = float(bbox_coords[3].strip(' [').strip(']').split(',')[1].strip(']'))
            pts = np.array([[int(l_t), int(t_l)], [int(r_t), int(t_r)], [int(r_b), int(b_r)], [int(l_b), int(b_l)]])

            if np.all(pts) > 0:

                word = crop(pts, image)

                folder = '/'.join(image_name.split('/')[:-1])

                dir = '/content/Pipeline/Crop Words/'
                if os.path.isdir(os.path.join(dir + folder)) == False:
                    os.makedirs(os.path.join(dir + folder))

                try:
                    file_name = os.path.join(dir + image_name)
                    cv2.imwrite(file_name+'_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t, t_r, r_b, b_r, l_b, b_l), word)
                    print('Image saved to '+file_name+'_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t, t_r, r_b, b_r, l_b, b_l))
                except:
                    continue


data = pd.read_csv('PATH TO CSV')
start = PATH TO TEST IMAGES

for image_num in range(data.shape[0]):
    image = cv2.imread(os.path.join(start, data['image_name'][image_num]))
    image_name = data['image_name'][image_num].strip('.jpg')
    score_bbox = data['word_bboxes'][image_num].split('),')
    generate_words(image_name, score_bbox, image)
Run the script:
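crop_images.py takes no command-line arguments, so once the PATH TO CSV and PATH TO TEST IMAGES placeholders are filled in, it can simply be invoked like this (Colab-style, for illustration):

!python crop_images.py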
Step 7: Recognition (at last!)
You could now simply run the recognition module on the cropped words as they are. But if you want things to be more organized, modify it as shown below: we create a .txt file in every image folder and save the recognized words alongside the names of the cropped images. In addition, the predicted words are also saved in the CSV we have been maintaining.
import string
import argparse

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torch.nn.functional as F

from utils import CTCLabelConverter, AttnLabelConverter
from dataset import RawDataset, AlignCollate
from model import Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import pandas as pd
import os


def demo(opt):
    """Open csv file wherein you are going to write the Predicted Words"""
    data = pd.read_csv('/content/Pipeline/data.csv')

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probabilty (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                # preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t {"predicted_labels":25s}\t confidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                start = PATH TO CROPPED WORDS
                path = os.path.relpath(img_name, start)
                folder = os.path.dirname(path)
                image_name = os.path.basename(path)
                file_name = '_'.join(image_name.split('_')[:-8])
                txt_file = os.path.join(start, folder, file_name)

                log = open(f'{txt_file}_log_demo_result_vgg.txt', 'a')
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= multiply of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                print(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}')
                log.write(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}\n')

                log.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')
    parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
    parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
    parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")
    """ Data processing """
    parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
    parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
    parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
    parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
    parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
    """ Model Architecture """
    parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
    parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')
    parser.add_argument('--output_channel', type=int, default=512, help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')

    opt = parser.parse_args()

    """ vocab / character number configuration """
    if opt.sensitive:
        opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    # print(opt.image_folder)
    # pred_words = demo(opt)
    demo(opt)
After downloading the weights from the Clova AI STR GitHub repository, you can run the following command:
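The command itself appeared as an image in the original post. Based on the arguments defined in the script above, a TPS-ResNet-BiLSTM-Attn run would look roughly like this; the weight file name and folder paths are assumptions, so adapt them to wherever you saved the downloaded model and the cropped words:

!python demo.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn --image_folder '/content/Pipeline/Crop Words/' --saved_model TPS-ResNet-BiLSTM-Attn.pth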
We chose this combination of networks because of its high accuracy. The CSV now looks like this: the pred_words column holds the detection-box coordinates and the predicted word, separated by a colon.
Conclusion
We have integrated two accurate models into a single detection-and-recognition module. Now that the predicted words and their bounding boxes each sit in their own column, you can align the text any way you want!
