pytorch 固定部分参数训练

Miracle
1002
文章
46
评论
2019年5月6日21:05:59 评论 796字阅读2分39秒

应用场景: 在加载预训练模型之后,在原来的基础上添加一部分的网络,我们可以固定原来的参数,然后只训练我们添加的这部分网络,完了之后再全部训练.

  1. class RESNET_attention(nn.Module):  
  2. def __init__(self, model, pretrained):  
  3. super(RESNET_attetnion, self).__init__()  
  4. self.resnet = model(pretrained)  
  5. for p in self.parameters():  
  6. p.requires_grad = False  
  7. self.f = nn.Conv2d(2048, 512, 1)  
  8. self.g = nn.Conv2d(2048, 512, 1)  
  9. self.h = nn.Conv2d(2048, 2048, 1)  
  10. self.softmax = nn.Softmax(-1)  
  11. self.gamma = nn.Parameter(torch.FloatTensor([0.0]))  
  12. self.avgpool = nn.AvgPool2d(7, stride=1)  
  13. self.resnet.fc = nn.Linear(2048, 10)  

这样就将for循环以上的参数固定, 只训练下面的参数(f,g,h,gamma,fc,等), 但是注意需要在optimizer中添加上这样的一句话filter(lambda p: p.requires_grad, model.parameters()
添加的位置为:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)

  • 版权声明: 发表于 2019年5月6日21:05:59
  • 转载注明:https://x1995.cn/5200.html
利用深度学习自动补全 Python 代码 干货教程

利用深度学习自动补全 Python 代码

代码补全功能在IDE里面十分常见,优秀的代码自动补全功能可以大大提升工作效率。不过, IDE 基本都使用搜索方法进行补全,在一些场景下效果不佳。今日,猿妹在GitHub上找到一个开源项目,使用深度学习...
分享8点超级有用的Python编程建议 干货教程

分享8点超级有用的Python编程建议

我们在用Python进行机器学习建模项目的时候,每个人都会有自己的一套项目文件管理的习惯,我自己也有一套方法,是自己曾经踩过的坑总结出来的,现在在这里分享一下给大家,希望多少有些地方可以给大家借鉴。 ...
一文读懂NFC 干货教程

一文读懂NFC

相信大家都知道NFC,如今几乎所有旗舰手机都支持这一功能,有了它,我们就可以用手机来乘坐公交地铁、用手机开门、用手机支付,那么你知道NFC到底是什么吗?IT之家现在给大家简单科普一下。 NFC全称为“...
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: