一、两种模式
pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。
一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。
(资料图片仅供参考)
二、功能
1. model.train()
在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。
如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。
model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。
2. model.eval()
model.eval()的作用是 不启用 Batch Normalization 和 Dropout。
如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。
model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。
为什么测试时要用 model.eval() ?
训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。
eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。
也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。
3. 总结与对比
如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。
其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;
而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。
三、Dropout 简介
dropout 常常用于抑制过拟合。
设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。
到此这篇关于详解model.train()和model.eval()两种模式的原理与用法的文章就介绍到这了,更多相关model.train()和model.eval()原理用法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
关键词:
X 关闭
X 关闭
- 15G资费不大降!三大运营商谁提供的5G网速最快?中国信通院给出答案
- 2联想拯救者Y70发布最新预告:售价2970元起 迄今最便宜的骁龙8+旗舰
- 3亚马逊开始大规模推广掌纹支付技术 顾客可使用“挥手付”结账
- 4现代和起亚上半年出口20万辆新能源汽车同比增长30.6%
- 5如何让居民5分钟使用到各种设施?沙特“线性城市”来了
- 6AMD实现连续8个季度的增长 季度营收首次突破60亿美元利润更是翻倍
- 7转转集团发布2022年二季度手机行情报告:二手市场“飘香”
- 8充电宝100Wh等于多少毫安?铁路旅客禁止、限制携带和托运物品目录
- 9好消息!京东与腾讯续签三年战略合作协议 加强技术创新与供应链服务
- 10名创优品拟通过香港IPO全球发售4100万股 全球发售所得款项有什么用处?