-
[Pytorch 에러] RuntimeError: Error(s) in loading state_dict - size mismatch for conv1.weightProgramming Error/PyTorch 2023. 3. 2. 19:46
Pytorch를 활용하여 Pretrain된 가중치를 기반으로 Transfer Learning을 수행하는 경우가 많이 있을텐데요.
torchvision과 같은 라이브러리에서는 Pretrained=True와 같은 인자를 지원하는 경우가 있지만, 연구를 하다보면 개인 연구자들이 올려준 Github로부터 모델 소스와 .pt(h) 확장자로 되어있는 파일을 불러오는 경우가 있을 겁니다.
그럴 때, 일반적으로 아래와 같은 코드로 Initialize를 수행하는데요.
state_dict = torch.load(weight_path) model.load_state_dict(state_dict, strict=True)
그런데 해당 가중치를 가져와서 다른 태스크를 수행한다고 하면 일부 레이어(보통 마지막 FC 레이어)를 제거하고 새로운 레이어를 추가할텐데요. 이러한 경우 위 인자의 strict=True를 False로 바꾸면 에러없이 바뀐 레이어를 제외한 나머지 레이어들의 가중치가 잘 로드됩니다.
그러나 제 경우에 strict=False로 두어도 아래와 같은 에러가 나는 것을 경험했습니다.
여기서 기존에 Pretrained할 때는 patch_embed.proj라고 하는 Layer의 Convolution size가 (16, 16)이었는데, 제가 새롭게 사이즈를 (128, 2)로 변경하였을 때 위와 같은 에러가 발생하였습니다. stackoverflow에 잘 검색해보니 load_state_dict는 key값 (즉, patch_embed.proj)을 통해 매칭이 되고 strict=False인 경우 매칭되는 해당 key 값이 없으면 무시하게 되는데patch_embed.proj라고 하는 key는 가중치 파일에도, 그리고 제 모델의 구조 양쪽에 존재하기 때문에 strict 인자로는 무시가 되지 않았습니다.
따라서 이를 해결하기 위해 아래와 같이 기존 가중치의 patch_embed.proj라고 하는 레이어의 키 명칭을 바꾸어주면 해당 레이어를 제외한 나머지 가중치가 잘 로드되는 것을 확인하였습니다.
state_dict = torch.load(weight_path) tmp_dict = OrderedDict() for i, j in state_dict.items(): # 가중치의 모든 키 값 반복문 name = i.replace("embed_proj","") # 매치되지 않는 키 값 변경 tmp_dict[name] = j model.load_state_dict(tmp_dict, strict=False)
'Programming Error > PyTorch' 카테고리의 다른 글