PyTorch通过requires_grad固定部分参数进行网络训练
# 1. 只训练部分层 class RESNET_attention(nn.Module): def __init__(self, model, pretrained): super(RESNET_attetnion, self).__init__() self.resnet = model(pretrained) for p in self.parameters(): p.requires_grad = False self.f = nn.Conv2d(2048, 512, 1) self.g = nn.Conv2d(2048, 512, 1) self.h =...
more...