diff --git a/src/trainer/model_with_neck.py b/src/trainer/model_with_neck.py index ceec383..2f36f39 100644 --- a/src/trainer/model_with_neck.py +++ b/src/trainer/model_with_neck.py @@ -585,7 +585,7 @@ class MoEModel(nn.Module): state_dict = torch.load( state_dict_path, weights_only=True, map_location=self.device ) - self.model.load_state_dict(state_dict) + self.load_state_dict(state_dict) def load_from_pretrained_base_model( self,