diff --git a/surya/common/surya/__init__.py b/surya/common/surya/__init__.py index 3b2472a..dca9562 100644 --- a/surya/common/surya/__init__.py +++ b/surya/common/surya/__init__.py @@ -104,7 +104,7 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel): _supports_static_cache = True _supports_attention_backend = True main_input_name = "input_ids" - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = ["lm_head.out_proj.weight"] def __init__( self, @@ -165,16 +165,16 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel): def _tie_weights(self): # Tie weights of lm head and token embedder - self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed) + self._tie_or_clone_weights(self.lm_head.out_proj, self.embedder.token_embed) def get_output_embeddings(self) -> nn.Module: - return self.lm_head + return self.lm_head.out_proj def get_input_embeddings(self) -> nn.Module: return self.embedder.token_embed def set_output_embeddings(self, new_embeddings: nn.Module): - self.lm_head = new_embeddings + self.lm_head.out_proj = new_embeddings def set_input_embeddings(self, new_embeddings: nn.Module): self.embedder.token_embed = new_embeddings