TensorFlow實現(xiàn)對花朵數(shù)據(jù)集的圖片分類
點擊下方卡片,關(guān)注“新機器視覺”公眾號
重磅干貨,第一時間送達
轉(zhuǎn)載自:古月居
轉(zhuǎn)載自:古月居
前言
利用TensorFlow實現(xiàn)對花朵數(shù)據(jù)集的圖片分類
提示:以下是本篇文章正文內(nèi)容,下面案例可供參考
一、數(shù)據(jù)集
數(shù)據(jù)集是五個分別存放著對應(yīng)類別花朵圖片的五個文件夾,包括daisy(雛菊)633張;dandelion(蒲公英)898張,rose(玫瑰)641張,sunflower(向日葵)699張,tulips(郁金香)799張。
二、代碼
1、下載數(shù)據(jù)集
import tensorflow as tfAUTOTUNE = tf.data.experimental.AUTOTUNEimport pathlibdata_root_orig = tf.keras.utils.get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz', fname='flower_photos', untar=True)data_root = pathlib.Path(data_root_orig)print(data_root)for item in data_root.iterdir(): ?? ?print(item)
打印下載后的文件路徑和文件成員:
output:
C:\Users\Administrator.keras\datasets\flower_photos
C:\Users\Administrator.keras\datasets\flower_photos\daisy
C:\Users\Administrator.keras\datasets\flower_photos\dandelion
C:\Users\Administrator.keras\datasets\flower_photos\LICENSE.txt
C:\Users\Administrator.keras\datasets\flower_photos\roses
C:\Users\Administrator.keras\datasets\flower_photos\sunflowers
C:\Users\Administrator.keras\datasets\flower_photos\tulips
2、統(tǒng)計并觀察數(shù)據(jù)
#獲取五個文件夾的名字label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())label_names
output:[‘daisy’, ‘dandelion’, ‘roses’, ‘sunflowers’, ‘tulips’]
#將文件夾的名字(即花的分類)標(biāo)上序號label_to_index = dict((name, index) for index, name in enumerate(label_names))label_to_index
{‘daisy’: 0, ‘dandelion’: 1, ‘roses’: 2, ‘sunflowers’: 3, ‘tulips’: 4}
#獲取所以圖片的標(biāo)簽(0,1,2,3,4)import randomall_image_paths = list(data_root.glob('*/*'))all_image_paths = [str(path) for path in all_image_paths]random.shuffle(all_image_paths)image_count = len(all_image_paths)image_countall_image_labels = [label_to_index[pathlib.Path(path).parent.name]for path in all_image_paths]print("First 10 labels indices: ", all_image_labels[:10])
3670
First 10 labels indices: [0, 2, 3, 4, 2, 1, 4, 1, 4, 0]
下面我們先來觀察一張圖片
#觀察第一張圖片img_path = all_image_paths[0]img_path
‘C:\Users\Administrator\.keras\datasets\flower_photos\daisy\11124324295_503f3a0804.jpg’
#讀取原圖img_raw = tf.io.read_file(img_path)#轉(zhuǎn)換成TensorFlow可以使用的tensor類型img_tensor = tf.image.decode_image(img_raw)print(img_tensor.shape)print(img_tensor.dtype)
(309, 500, 3)
#對圖片按要求進行轉(zhuǎn)換,這里將size規(guī)定到【192,192】;值域映射到【0,1】img_final = tf.image.resize(img_tensor, [192, 192])img_final = img_final/255.0print(img_final.shape)print(img_final.numpy().min())print(img_final.numpy().max())
(192, 192, 3)
0.0
0.99984366
定義預(yù)處理和加載函數(shù)
def preprocess_image(image): ?? ?image = tf.image.decode_jpeg(image, channels=3) ?? ?image = tf.image.resize(image, [192, 192]) ?? ?image /= 255.0 ?? ?return imagedef load_and_preprocess_image(path): ?? ?image = tf.io.read_file(path) ?? ?return preprocess_image(image)
導(dǎo)入matpoltlib進行畫圖(導(dǎo)入失敗的解決方案見我的另一篇博文)
import matplotlib.pyplot as pltimage_path = all_image_paths[0]label = all_image_labels[0]print (load_and_preprocess_image(img_path))plt.imshow(load_and_preprocess_image(img_path))plt.grid(False)
3、構(gòu)建數(shù)據(jù)集

使用tf.data.Dataset來構(gòu)建規(guī)范的數(shù)據(jù)集
#“from_tensor_slices ”方法使用張量的切片元素構(gòu)建圖片路徑的數(shù)據(jù)集path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)#同理,構(gòu)建標(biāo)簽數(shù)據(jù)集,并用tf.cast轉(zhuǎn)換成int64數(shù)據(jù)類型label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))#根據(jù)路徑獲取圖片,并經(jīng)過加載和預(yù)處理得到圖片數(shù)據(jù)集image_ds = path_ds.map(load_and_preprocess_image )image_ds
將image_ds和label_ds打包成新的數(shù)據(jù)集
image_label_ds = tf.data.Dataset.zip((image_ds, label_ds))print(image_label_ds)
畫圖查看我們構(gòu)造完成的新數(shù)據(jù)
plt.figure(figsize=(8,8))for n,image_label in enumerate(image_label_ds.take(4)):? ?plt.subplot(2,2,n+1)? ?plt.imshow(image_label[0])? ?plt.grid(False)

4、遷移學(xué)習(xí)進行分類
接下來用創(chuàng)建的數(shù)據(jù)集訓(xùn)練一個分類模型,簡單起見,直接用tf.keras.applications包中訓(xùn)練好的模型,并將其遷移到我們的圖片分類問題上來。這里使用的模型是MobileNetV2模型
#遷移MobileNetV2模型,并且不加載頂層base_model=tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False,weights='imagenet',input_shape=(192,192,3))inputs=tf.keras.layers.Input(shape=(192,192,3))#模型可視化1,使用model.summary()方法base_model.summary()
接下來,我們打亂一下數(shù)據(jù)集,并定義好訓(xùn)練過程中每個批次(Batch)數(shù)據(jù)的大小
#使用shuffle方法打亂數(shù)據(jù)集image_count = len(all_image_paths)ds = image_label_ds.shuffle(buffer_size = image_count)#讓數(shù)據(jù)集重復(fù)多次ds = ds.repeat()#設(shè)置每個批次的大小BATCH_SIZE = 32ds = ds.batch(BATCH_SIZE)#通過prefetch方法讓模型的訓(xùn)練和每個批次數(shù)據(jù)的加載并行ds = ds.prefetch(buffer_size = AUTOTUNE)
然后,針對MobileNetV2改變一樣數(shù)據(jù)集的取值范圍,因為MobileNetV2接受輸入的數(shù)據(jù)值域是【-1,1】,而我們之前的預(yù)處理函數(shù)將圖片的像素值映射到【0,1】
def change_range(image,label):? ?return 2*image-1,labelkeras_ds = ds.map(change_range)
接下來定義模型,由于預(yù)訓(xùn)練好的MobileNetV2返回的數(shù)據(jù)維度是(32,6,6,128),其中32是一個批次Batch的大小,“6,6”是輸出的特征的大小為6*6,1280代表該層使用的1280個卷積核。為了使用花朵分類問題,需要做一下調(diào)整
model = tf.keras.Sequential([? ?base_model,? ?tf.keras.layers.GlobalAveragePooling2D(),? ?tf.keras.layers.Dense(len(label_names),activation="softmax")? ?])
如上代碼,我們用Sequentail建立我們的網(wǎng)絡(luò)結(jié)構(gòu),base_model是遷移過來的模型,我們添加了全局評價池化層GlobalAveragePooling,經(jīng)過此操作6*6的特征被降維,變?yōu)椋?2,1280)。
最后,由于該分類問題有五個結(jié)果,我們增加一個全連接層(Dense)將維度變?yōu)椋?2,5)。
最后,編譯一下模型,同時制定使用的優(yōu)化器,損失函數(shù)和評價標(biāo)準(zhǔn)
model.compile(optimizer = tf.keras.optimizers.Adam(),? ? ? ? ? ? loss='sparse_categorical_crossentropy',? ? ? ? ? ? metrics=['accuracy'])model.summary()

使用model.fit訓(xùn)練模型,epochs是訓(xùn)練的回合數(shù),step_per_epoch代表每個回合要去多少個批次數(shù)據(jù)。通常等于我們數(shù)據(jù)集大小除以批次大小后取證(3670/32≈10)
model.fit(ds,epochs=10,steps_per_epoch=100)
雖然沒有跑完整個代碼,但是已經(jīng)能看出來準(zhǔn)確度達到一個很高的程度,并在逐步上升。
本文僅做學(xué)術(shù)分享,如有侵權(quán),請聯(lián)系刪文。
