pytorch load_state_dict
-
[Pytorch 에러] RuntimeError: Error(s) in loading state_dict - size mismatch for conv1.weightProgramming Error/PyTorch 2023. 3. 2. 19:46
Pytorch를 활용하여 Pretrain된 가중치를 기반으로 Transfer Learning을 수행하는 경우가 많이 있을텐데요. torchvision과 같은 라이브러리에서는 Pretrained=True와 같은 인자를 지원하는 경우가 있지만, 연구를 하다보면 개인 연구자들이 올려준 Github로부터 모델 소스와 .pt(h) 확장자로 되어있는 파일을 불러오는 경우가 있을 겁니다. 그럴 때, 일반적으로 아래와 같은 코드로 Initialize를 수행하는데요. state_dict = torch.load(weight_path) model.load_state_dict(state_dict, strict=True) 그런데 해당 가중치를 가져와서 다른 태스크를 수행한다고 하면 일부 레이어(보통 마지막 FC 레이어)를 제거..