[Pytorch 에러] RuntimeError: Error(s) in loading state_dict - size mismatch for conv1.weight
Pytorch를 활용하여 Pretrain된 가중치를 기반으로 Transfer Learning을 수행하는 경우가 많이 있을텐데요.
torchvision과 같은 라이브러리에서는 Pretrained=True와 같은 인자를 지원하는 경우가 있지만, 연구를 하다보면 개인 연구자들이 올려준 Github로부터 모델 소스와 .pt(h) 확장자로 되어있는 파일을 불러오는 경우가 있을 겁니다.
그럴 때, 일반적으로 아래와 같은 코드로 Initialize를 수행하는데요.
state_dict = torch.load(weight_path)
model.load_state_dict(state_dict, strict=True)
그런데 해당 가중치를 가져와서 다른 태스크를 수행한다고 하면 일부 레이어(보통 마지막 FC 레이어)를 제거하고 새로운 레이어를 추가할텐데요. 이러한 경우 위 인자의 strict=True를 False로 바꾸면 에러없이 바뀐 레이어를 제외한 나머지 레이어들의 가중치가 잘 로드됩니다.
그러나 제 경우에 strict=False로 두어도 아래와 같은 에러가 나는 것을 경험했습니다.
따라서 이를 해결하기 위해 아래와 같이 기존 가중치의 patch_embed.proj라고 하는 레이어의 키 명칭을 바꾸어주면 해당 레이어를 제외한 나머지 가중치가 잘 로드되는 것을 확인하였습니다.
state_dict = torch.load(weight_path)
tmp_dict = OrderedDict()
for i, j in state_dict.items(): # 가중치의 모든 키 값 반복문
name = i.replace("embed_proj","") # 매치되지 않는 키 값 변경
tmp_dict[name] = j
model.load_state_dict(tmp_dict, strict=False)