From 8078598573d7b211dad7fac2d0735274c273aea8 Mon Sep 17 00:00:00 2001
From: Tarun Menta
Date: Wed, 8 Oct 2025 13:11:02 -0400
Subject: [PATCH] Fix weight tying for new output head

---
 surya/common/surya/__init__.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/surya/common/surya/__init__.py b/surya/common/surya/__init__.py
index 3b2472a..dca9562 100644
--- a/surya/common/surya/__init__.py
+++ b/surya/common/surya/__init__.py
@@ -104,7 +104,7 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     main_input_name = "input_ids"
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.out_proj.weight"]
 
     def __init__(
         self,
@@ -165,16 +165,16 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
 
     def _tie_weights(self):
         # Tie weights of lm head and token embedder
-        self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)
+        self._tie_or_clone_weights(self.lm_head.out_proj, self.embedder.token_embed)
 
     def get_output_embeddings(self) -> nn.Module:
-        return self.lm_head
+        return self.lm_head.out_proj
 
     def get_input_embeddings(self) -> nn.Module:
         return self.embedder.token_embed
 
     def set_output_embeddings(self, new_embeddings: nn.Module):
-        self.lm_head = new_embeddings
+        self.lm_head.out_proj = new_embeddings
 
     def set_input_embeddings(self, new_embeddings: nn.Module):
         self.embedder.token_embed = new_embeddings