当前位置: 华文星空 > 宠物

利用ResNet

2021-07-19宠物

一、前言

本文基于残差网络模型,通过对ResNet-50模型进行微调,对不同狗狗品种数据集进行鉴定。

Dog Breed Identification数据集包含20579张不同size的彩色图片,共分为120类犬种。其中训练集包含10222张图片,测试集包含10357张图片。犬种数据集样本图如下所示。

二、基于Fine Tuning构建ResNet-50模型

随着模型深度的提升,出现网络出现退化问题,因此残差网络应运而生。通过微调已经构建好的模型,能够在相似的数据集运用,不再需要重新训练模型。本文模型构建分为四个部分:数据读取及预处理、构建ResNet-50模型以及模型微调、定义模型超参数以及评估方法、参数优化。

1、数据读取及预处理

本文采用pandas对数据进行预处理,以及通过GPU对运算进行提速。

import os , torch , torchvision import torch.nn as nn import torch.nn.functional as F import numpy as np import pandas as pd from torch.utils.data import DataLoader , Dataset from torchvision import datasets , models , transforms from PIL import Image from sklearn.model_selection import StratifiedShuffleSplit device = torch . device ( "cuda:0" if torch . cuda . is_available () else "cpu" )

通过pandas读取csv文件,显示前5条数据,发现原数据只存在两列数据,id对应图片名,breed对应犬种名称,总共有120种。

data_root = 'data' all_labels_df = pd . read_csv ( os . path . join ( data_root , 'labels.csv' )) all_labels_df . head ()

根据犬种名称,将其转化为id的形式进行一一对应,并在数据集中增加新列,表示犬种类别的id标签。

breeds = all_labels_df . breed . unique () breed2idx = dict (( breed , idx ) for idx , breed in enumerate ( breeds )) idx2breed = dict (( idx , breed ) for idx , breed in enumerate ( breeds )) all_labels_df [ 'label_idx' ] = all_

< style data-emotion-css="19xugg7"> .css-19xugg7{position:absolute;width:100%;bottom:0;background-image:linear-gradient(to bottom,transparent,#ffffff 50px);} < style data-emotion-css="12cv0pi"> .css-12cv0pi{box-sizing:border-box;margin:0;min-width:0;height:100px;-webkit-box-pack:center;-webkit-justify-content:center;-ms-flex-pack:center;justify-content:center;display:-webkit-box;display:-webkit-flex;display:-ms-flexbox;display:flex;position:absolute;width:100%;bottom:0;background-image:linear-gradient(to bottom,transparent,#ffffff 50px);}
< style data-emotion-css="1pr2waf"> .css-1pr2waf{font-size:15px;color:#09408e;}
发布于 2021-07-19 19:40
< style data-emotion-css="ch8ocw"> .css-ch8ocw{position:relative;display:inline-block;height:30px;padding:0 12px;font-size:14px;line-height:30px;color:#1772F6;vertical-align:top;border-radius:100px;background:rgba(23,114,246,0.1);}.css-ch8ocw:hover{background-color:rgba(23,114,246,0.15);}
< style data-emotion-css="1xlfegr"> .css-1xlfegr{background:transparent;box-shadow:none;} < style data-emotion-css="1gomreu"> .css-1gomreu{position:relative;display:inline-block;}
残差网络