當前位置:網站首頁>*精度優化*優化策略1:網絡+SAM優化器
*精度優化*優化策略1:網絡+SAM優化器
2022-07-23 05:01:12【夏天|여름이다】
一:SAM優化器介紹:
SAM:Sharpness Awareness Minimization銳度感知最小化
SAM不是一個新的優化器,它與其他常見的優化器一起使用,比如SGD/Adam。
論文:2020 Sharpness-Aware Minimization for Efficiently Improving Generalization
論文地址:https://arxiv.org/pdf/2010.01412v2.pdf
項目地址:GitHub - davda54/sam: SAM: Sharpness-Aware Minimization (PyTorch)
(依舊建議大家使用GPU去訓練,一般電腦cpu可以運行,但是非常卡,能卡出數據集,但是沒卡出結果。)
下載解壓後非常簡單,把sam.py文件直接複制到example文件夾下就可以直接跑train.py.
運行後會自動下載數據集,會進行批次訓練。
運行結果:(我改的epochs比較小,改大效果更好)
重要部分如下train.py:
import argparse
import torch
from model.wide_res_net import WideResNet#導入模型中的wide_res_net網絡
from model.smooth_cross_entropy import smooth_crossentropy#導入損失函數
from data.cifar import Cifar#導入數據集
from utility.log import Log#導入工具類日志文件
from utility.initialize import initialize#導入工具類初始化
from utility.step_lr import StepLR#導入工具類階梯學習率
from utility.bypass_bn import enable_running_stats, disable_running_stats#導入工具類繞過BN,啟用運行統計,禁用運行統計
import sys; sys.path.append("..")#導入sys.path中需要用到的XXX包,然後加載
from sam import SAM#引入SAM
if __name__ == "__main__":
#創建解析器(arg對象)
parser = argparse.ArgumentParser()
#添加參數
parser.add_argument("--adaptive", default=True, type=bool, help="True if you want to use the Adaptive SAM.")
parser.add_argument("--batch_size", default=12, type=int, help="Batch size used in the training and validation loop.")
parser.add_argument("--depth", default=16, type=int, help="Number of layers.")
parser.add_argument("--dropout", default=0.0, type=float, help="Dropout rate.")
parser.add_argument("--epochs", default=2, type=int, help="Total number of epochs.")
parser.add_argument("--label_smoothing", default=0.1, type=float, help="Use 0.0 for no label smoothing.")
parser.add_argument("--learning_rate", default=0.1, type=float, help="Base learning rate at the start of the training.")
parser.add_argument("--momentum", default=0.9, type=float, help="SGD Momentum.")
parser.add_argument("--threads", default=2, type=int, help="Number of CPU threads for dataloaders.")
parser.add_argument("--rho", default=2.0, type=int, help="Rho parameter for SAM.")
parser.add_argument("--weight_decay", default=0.0005, type=float, help="L2 weight decay.")
parser.add_argument("--width_factor", default=8, type=int, help="How many times wider compared to normal ResNet.")
#解析參數
args = parser.parse_args()
#初始化
initialize(args, seed=42)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#定義數據集
dataset = Cifar(args.batch_size, args.threads)
#記錄日志
log = Log(log_each=10)
#定義模型
model = WideResNet(args.depth, args.width_factor, args.dropout, in_channels=3, labels=10).to(device)
#定義基礎優化器
base_optimizer = torch.optim.SGD
#定義第二個優化器SAM
optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
#將optimizer作為參數傳遞給scheduler,每次通過調用scheduler.step()就會更新optimizer中每一個param_group[‘lr’],每過固定個epoch,學習率會按照gamma倍率進行衰减。
scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
for epoch in range(args.epochs):
model.train()
log.train(len_dataset=len(dataset.train))
for batch in dataset.train:
inputs, targets = (b.to(device) for b in batch)
# first forward-backward step
enable_running_stats(model)
predictions = model(inputs)
loss = smooth_crossentropy(predictions, targets, smoothing=args.label_smoothing)
loss.mean().backward()
optimizer.first_step(zero_grad=True)
# second forward-backward step
disable_running_stats(model)
smooth_crossentropy(model(inputs), targets, smoothing=args.label_smoothing).mean().backward()
optimizer.second_step(zero_grad=True)
with torch.no_grad():
correct = torch.argmax(predictions.data, 1) == targets
log(model, loss.cpu(), correct.cpu(), scheduler.lr())
scheduler(epoch)
model.eval()
log.eval(len_dataset=len(dataset.test))
with torch.no_grad():
for batch in dataset.test:
inputs, targets = (b.to(device) for b in batch)
predictions = model(inputs)
loss = smooth_crossentropy(predictions, targets)
correct = torch.argmax(predictions, 1) == targets
log(model, loss.cpu(), correct.cpu())
log.flush()
二:把SAM應用到自己的項目上:
step1:把SAM的工具文件複制到自己的項目下
把utility文件夾複制到自己的項目下,
把sam.py複制到項目根目錄,在train.py裏導入包。
step2:把數據集改為自己的數據集
step3:把網絡改為自己的網絡
(我的項目是多個獨立的網絡,幾個網絡就寫幾遍)
step4:添加基礎優化器和SAM
base_optimizer = torch.optim.SGD
optimizer = SAM(model.parameters(), base_optimizer, rho=args.rho, adaptive=args.adaptive, lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
scheduler = StepLR(optimizer, args.learning_rate, args.epochs)
#一定要根據自己的項目去改相關參數等step5:把損失函數改為自己原本的損失函數添加SAM工具類
...
#opt.zero_grad()注釋掉原本的
#添加SAM工具類裏的函數
enable_running_stats(model_context)
enable_running_stats(model_body)
enable_running_stats(emotic_model)
#我的項目是三個網絡。如果是一個網絡的話,寫一次
#類似於enable_running_stats(model)
...
loss.backward()
opt.first_step(zero_grad=True)#在項目的loss反向傳播後先用優化器first step
#添加SAM工具類裏的函數
disable_running_stats(model_context)
disable_running_stats(model_body)
disable_running_stats(emotic_model)
...
loss.backward()
opt.second_step(zero_grad=True)#在項目的loss反向傳播後再用優化器second step
# opt.step()注釋掉原本的
step6:添加SAM所需的超參數(可選,不改也不會出錯)
原項目:
修改後:
原項目:
修改後:
#黃色框為修改比特置
加入SAM優化器後,比原來精度提高了將近3%。
以上。(全是自己的理解,不正確望指正,感謝。)
版權聲明
本文為[夏天|여름이다]所創,轉載請帶上原文鏈接,感謝
https://cht.chowdera.com/2022/204/202207221752585451.html
邊欄推薦
猜你喜歡
[論文翻譯] Generalized Radiograph Representation Learning via Cross-Supervision Between Images
codeforce D2. RGB Substring (hard version) 滑動窗口
服務器buffer/cache 的產生原因和釋放buffer/cache
NFS共享存儲服務
MySQL 增删改查(進階)
十七、C函數指針與回調函數
QT筆記—— QTableWidget 之 拖拽行數 和 移動
女嘉賓報名
MySQL密碼正確但是啟動報錯Unable to create initial connections of pool.Access denied for user ‘root‘@‘localhost
【SDIO】SD2.0協議分析總結(三)-- SD卡相關命令介紹
隨機推薦
- App移動端測試【6】應用程序(apk)包管理與activity
- Qt | 模態對話框和非模態對話框 QDialog
- 在各類數據庫中隨機查詢n條數據
- 二、IDEA搭建JFinal項目+代碼自動生成+數據庫操作測試(三種方式)
- Flutter 第一個程序Hello World!
- 派生類的構造函數和析構函數
- NewSQL數據庫數據模型設計
- 2017年終總結
- dns劫持如何完美修複?dns被劫持如何解决如何完美修複
- flask 跨域
- 鏈棧實現(C語言)
- DETR 論文精讀,並解析模型結構
- 【FPGA】:ip核--DDR3
- 微信小程序Cannot read property 'setData' of null錯誤
- BUUCTF闖關日記--[網鼎杯 2020 青龍組]AreUSerialz
- 嵌入式系統學習筆記
- 水庫河道應急廣播系統解决方案
- Cartesi 2022 年 3 月回顧
- Daily Leetcode-11 分治
- 智源社區AI周刊#90:馬毅認為智能不可能依賴大算力實現;Hugging Face博客揭秘Bloom訓練細節;ICML最佳論文獎公布
- TypeScript
- 開源工具 SAP UI5 Tools 介紹
- Lark教程指南
- 網絡安全——使用Evil Maid物理訪問安全漏洞進行滲透
- 網絡安全—使用Ubuntu本地提權漏洞進行滲透及加固
- JWT工具類編寫
- Day1 Running Sum of 1d Array/Find Pivot Index/用兩個棧實現隊列
- socket編程之常用api介紹與socket、select、poll、epoll高並發服務器模型代碼實現
- 深入研究容器隊列
- Bean的初始化回調方法和釋放資源的回調方法
- 爬蟲數據保存到mysql數據庫
- 通過SQL進行數據分發
- Redis 分布式鎖如何自動續期(經典解决方案)
- 虹科動態 | cippe2022即將舉辦,報名火熱進行中
- Kotlin之匿名內部類(object: xxxx)
- 面試突擊:truncate、delete和drop的6大區別
- Ubuntu安裝Docker及Docker的基本命令 安裝MySQL
- LeetCode--棧和隊列篇
- etcd 集群部署
- TCP/IP協議族中需要必知必會的十大問題