Fix weight tying for new output head
Some checks failed
Integration test / build (push) Has been cancelled
Unit tests / build (t4_gpu) (push) Has been cancelled
Unit tests / build (ubuntu-latest) (push) Has been cancelled
Unit tests / build (windows-latest) (push) Has been cancelled
Test CLI scripts / build (push) Has been cancelled

This commit is contained in:
Tarun Menta 2025-10-08 13:11:02 -04:00
parent 0d563ee949
commit 8078598573
No known key found for this signature in database

View File

@ -104,7 +104,7 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
_supports_static_cache = True
_supports_attention_backend = True
main_input_name = "input_ids"
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = ["lm_head.out_proj.weight"]
def __init__(
self,
@ -165,16 +165,16 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
def _tie_weights(self):
# Tie weights of lm head and token embedder
self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)
self._tie_or_clone_weights(self.lm_head.out_proj, self.embedder.token_embed)
def get_output_embeddings(self) -> nn.Module:
return self.lm_head
return self.lm_head.out_proj
def get_input_embeddings(self) -> nn.Module:
return self.embedder.token_embed
def set_output_embeddings(self, new_embeddings: nn.Module):
self.lm_head = new_embeddings
self.lm_head.out_proj = new_embeddings
def set_input_embeddings(self, new_embeddings: nn.Module):
self.embedder.token_embed = new_embeddings