Skip to content

Commit

Permalink
forward partial for classification purposes
Browse files Browse the repository at this point in the history
  • Loading branch information
timlacroix committed Oct 17, 2023
1 parent e04eae8 commit a165337
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions mistral/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.tok_embeddings.weight.device

def forward(
def forward_partial(
self,
input_ids: torch.Tensor,
cache: RotatingBufferCache,
Expand All @@ -210,7 +210,17 @@ def forward(

cache.update_seqlens(seqlens)

return self.output(self.norm(h)).float()
return self.norm(h)

def forward(
self,
input_ids: torch.Tensor,
cache: RotatingBufferCache,
seqlens: List[int],
) -> torch.Tensor:
return self.output(self.forward_partial(
input_ids, cache, seqlens
)).float()

@staticmethod
def from_folder(folder: Path, max_batch_size: int = 1, device="cuda", dtype=torch.float16) -> "Transformer":
Expand Down

0 comments on commit a165337

Please sign in to comment.