mirror of
https://github.com/VikParuchuri/surya.git
synced 2026-06-04 21:03:53 +08:00
Fix weight tying for new output head
Some checks failed
Integration test / build (push) Has been cancelled
Unit tests / build (t4_gpu) (push) Has been cancelled
Unit tests / build (ubuntu-latest) (push) Has been cancelled
Unit tests / build (windows-latest) (push) Has been cancelled
Test CLI scripts / build (push) Has been cancelled
Some checks failed
Integration test / build (push) Has been cancelled
Unit tests / build (t4_gpu) (push) Has been cancelled
Unit tests / build (ubuntu-latest) (push) Has been cancelled
Unit tests / build (windows-latest) (push) Has been cancelled
Test CLI scripts / build (push) Has been cancelled
This commit is contained in:
parent
0d563ee949
commit
8078598573
@@ -104,7 +104,7 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
     _supports_static_cache = True
     _supports_attention_backend = True
     main_input_name = "input_ids"
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ["lm_head.out_proj.weight"]
 
     def __init__(
         self,
@@ -165,16 +165,16 @@ class SuryaModel(S3DownloaderMixin, SuryaPreTrainedModel):
 
     def _tie_weights(self):
         # Tie weights of lm head and token embedder
-        self._tie_or_clone_weights(self.lm_head, self.embedder.token_embed)
+        self._tie_or_clone_weights(self.lm_head.out_proj, self.embedder.token_embed)
 
     def get_output_embeddings(self) -> nn.Module:
-        return self.lm_head
+        return self.lm_head.out_proj
 
     def get_input_embeddings(self) -> nn.Module:
         return self.embedder.token_embed
 
     def set_output_embeddings(self, new_embeddings: nn.Module):
-        self.lm_head = new_embeddings
+        self.lm_head.out_proj = new_embeddings
 
     def set_input_embeddings(self, new_embeddings: nn.Module):
         self.embedder.token_embed = new_embeddings
Loading…
Reference in New Issue
Block a user