當前位置: 華文星空 > 寵物

利用ResNet

2021-07-19寵物

一、前言

本文基於殘留誤差網絡模型,透過對ResNet-50模型進行微調,對不同狗狗品種數據集進行鑒定。

Dog Breed Identification數據集包含20579張不同size的彩色圖片,共分為120類犬種。其中訓練集包含10222張圖片,測試集包含10357張圖片。犬種數據集樣本圖如下所示。

二、基於Fine Tuning構建ResNet-50模型

隨著模型深度的提升,出現網絡出現退化問題,因此殘留誤差網絡應運而生。透過微調已經構建好的模型,能夠在相似的數據集運用,不再需要重新訓練模型。本文模型構建分為四個部份:數據讀取及預處理、構建ResNet-50模型以及模型微調、定義模型超參數以及評估方法、參數最佳化。

1、數據讀取及預處理

本文采用pandas對數據進行預處理,以及透過GPU對運算進行提速。

import os , torch , torchvision import torch.nn as nn import torch.nn.functional as F import numpy as np import pandas as pd from torch.utils.data import DataLoader , Dataset from torchvision import datasets , models , transforms from PIL import Image from sklearn.model_selection import StratifiedShuffleSplit device = torch . device ( "cuda:0" if torch . cuda . is_available () else "cpu" )

透過pandas讀取csv檔,顯示前5條數據,發現原數據只存在兩列數據,id對應圖片名,breed對應犬種名稱,總共有120種。

data_root = 'data' all_labels_df = pd . read_csv ( os . path . join ( data_root , 'labels.csv' )) all_labels_df . head ()

根據犬種名稱,將其轉化為id的形式進行一一對應,並在數據集中增加新列,表示犬種類別的id標簽。

breeds = all_labels_df . breed . unique () breed2idx = dict (( breed , idx ) for idx , breed in enumerate ( breeds )) idx2breed = dict (( idx , breed ) for idx , breed in enumerate ( breeds )) all_labels_df [ 'label_idx' ] = all_

< style data-emotion-css="19xugg7"> .css-19xugg7{position:absolute;width:100%;bottom:0;background-image:linear-gradient(to bottom,transparent,#ffffff 50px);} < style data-emotion-css="12cv0pi"> .css-12cv0pi{box-sizing:border-box;margin:0;min-width:0;height:100px;-webkit-box-pack:center;-webkit-justify-content:center;-ms-flex-pack:center;justify-content:center;display:-webkit-box;display:-webkit-flex;display:-ms-flexbox;display:flex;position:absolute;width:100%;bottom:0;background-image:linear-gradient(to bottom,transparent,#ffffff 50px);}
< style data-emotion-css="1pr2waf"> .css-1pr2waf{font-size:15px;color:#09408e;}
釋出於 2021-07-19 19:40
< style data-emotion-css="ch8ocw"> .css-ch8ocw{position:relative;display:inline-block;height:30px;padding:0 12px;font-size:14px;line-height:30px;color:#1772F6;vertical-align:top;border-radius:100px;background:rgba(23,114,246,0.1);}.css-ch8ocw:hover{background-color:rgba(23,114,246,0.15);}
< style data-emotion-css="1xlfegr"> .css-1xlfegr{background:transparent;box-shadow:none;} < style data-emotion-css="1gomreu"> .css-1gomreu{position:relative;display:inline-block;}
殘留誤差網絡