Implement static cache
VikParuchuri committed May 24, 2024
1 parent f42cfa7 commit 17ef170
Showing 2 changed files with 88 additions and 43 deletions.
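For orientation before the diff: a "static" KV cache is preallocated once at a fixed shape and filled in place at the current token offset, rather than growing past_key_values tuples on every decoding step. A minimal, hypothetical sketch of the pattern (shapes, sizes, and names here are illustrative, not this repository's API):

import torch

# Illustrative only: preallocate once, then write each step's keys/values in place.
layers, bs, kv_heads, max_tokens, head_dim = 7, 4, 4, 256, 64
cache = torch.zeros(layers, 2, bs, kv_heads, max_tokens, head_dim)  # dim 1 = key/value

def write_step(cache, layer_idx, new_keys, new_values, token_count):
    # new_keys/new_values: (bs, kv_heads, step_len, head_dim) for the tokens just processed
    step_len = new_keys.shape[2]
    cache[layer_idx, 0, :, :, token_count:token_count + step_len] = new_keys
    cache[layer_idx, 1, :, :, token_count:token_count + step_len] = new_values
    return token_count + step_len

Because nothing is reallocated as generation proceeds, every forward pass sees identically shaped tensors, which is the usual motivation for a static cache.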
27 changes: 17 additions & 10 deletions surya/model/recognition/decoder.py
@@ -194,7 +194,10 @@ def forward(
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
+ if is_cross_attention:
+     past_key_value = (key_states, value_states)
+ else:
+     past_key_value = (key_states[:, :, -tgt_len:], value_states[:, :, -tgt_len:])

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
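The effect of this change is that the attention layer hands back exactly what the caller needs to store: cross-attention returns the full projected encoder keys/values (computed once, they never change during decoding), while self-attention returns only the slice for the positions processed in this call, since everything earlier already sits in the caller's preallocated cache. A hedged sketch of the shapes involved, assuming the concat-then-attend behaviour the surrounding comments describe (values are illustrative):

import torch

# Illustrative shapes only.
bs, kv_heads, head_dim = 2, 4, 64
past_len, tgt_len = 10, 1

past_k = torch.zeros(bs, kv_heads, past_len, head_dim)  # already written to the static cache
new_k = torch.ones(bs, kv_heads, tgt_len, head_dim)     # projected for the current step

key_states = torch.cat([past_k, new_k], dim=2)          # what attention actually attends over
to_cache = key_states[:, :, -tgt_len:]                  # only the new tail is handed back
assert torch.equal(to_cache, new_k)                     # i.e. just the freshly projected states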
@@ -290,11 +293,11 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
langs: Optional[torch.LongTensor] = None,
+ kv_caches: Optional[List[torch.Tensor]] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
) -> torch.Tensor:
@@ -303,7 +306,7 @@ def forward(

# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attn_past_key_value = kv_caches[0] if kv_caches is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
@@ -323,7 +326,7 @@ def forward(
hidden_states = self.encoder_attn_layer_norm(hidden_states)

# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attn_past_key_value = kv_caches[1] if kv_caches is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -393,12 +396,13 @@ def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
+ kv_caches: Optional[List[torch.Tensor]] = None,
+ past_token_count: Optional[int] = None,
langs: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -426,7 +430,7 @@ def forward(
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

# past_key_values_length
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ past_key_values_length = past_token_count if kv_caches is not None else 0

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
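With a fixed-size cache, the number of previously generated tokens can no longer be read off the cache tensor's shape (its time dimension is always the preallocated maximum), so the caller now passes it in explicitly as past_token_count, and the decoder uses it only to offset the positions of the tokens being fed in. A minimal sketch of that offset, assuming a learned absolute positional embedding (the module and sizes below are stand-ins, not the repository's code):

import torch
import torch.nn as nn

def position_ids_for_step(input_ids, past_token_count, kv_caches=None):
    # With a static cache, length must come from an explicit counter, not a tensor shape.
    past_key_values_length = past_token_count if kv_caches is not None else 0
    return torch.arange(
        past_key_values_length,
        past_key_values_length + input_ids.shape[1],
        device=input_ids.device,
    )

# Usage sketch with stand-in modules/sizes:
embed_positions = nn.Embedding(512, 1024)
input_ids = torch.tensor([[42]])
pos = position_ids_for_step(input_ids, past_token_count=7, kv_caches=[torch.zeros(1)])
pos_embeds = embed_positions(pos)  # shape (1, 1024): position for the one new token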
@@ -485,18 +489,18 @@ def forward(
if dropout_probability < self.layerdrop:
continue

- past_key_value = past_key_values[idx] if past_key_values is not None else None
+ kv_cache = [kv_caches[0][idx], kv_caches[1][idx]] if kv_caches is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
langs=langs,
+ kv_caches=kv_cache,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
- past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
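Because the caches are stacked with a leading layer dimension (see the allocation further down in surya/recognition.py: the decoder self-attention cache is layers x 2 x batch x kv_heads x max_tokens x head_dim, and the cross-attention cache is the same with the encoder length instead of max_tokens), each decoder layer receives its own [self_attn, cross_attn] pair simply by indexing with the layer index. A hedged sketch of that selection (the helper name and example sizes are hypothetical):

import torch
from typing import List, Optional

def select_layer_cache(kv_caches: Optional[List[torch.Tensor]], idx: int):
    # kv_caches[0]: (layers, 2, bs, kv_heads, max_tokens, head_dim)  decoder self-attention
    # kv_caches[1]: (layers, 2, bs, kv_heads, enc_len, head_dim)     encoder cross-attention
    if kv_caches is None:
        return None
    return [kv_caches[0][idx], kv_caches[1][idx]]  # per-layer (key, value) stacks

# Usage sketch with illustrative shapes:
decoder_cache = torch.zeros(7, 2, 4, 4, 256, 64)
encoder_cache = torch.zeros(7, 2, 4, 4, 196, 64)
layer_cache = select_layer_cache([decoder_cache, encoder_cache], idx=3)
print(layer_cache[0].shape)  # torch.Size([2, 4, 4, 256, 64])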
@@ -567,12 +571,14 @@ def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
+ kv_caches: Optional[List[torch.FloatTensor]] = None,
+ past_token_count: Optional[int] = None,
langs: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -590,12 +596,13 @@ def forward(
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
+ kv_caches=kv_caches,
+ past_token_count=past_token_count,
langs=langs,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
- past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
104 changes: 71 additions & 33 deletions surya/recognition.py
@@ -48,41 +48,79 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
batch_decoder_input = torch.from_numpy(np.array(batch_decoder_input, dtype=np.int64)).to(model.device)

+ token_count = 0
+ encoder_outputs = None
+ batch_predictions = [[] for _ in range(len(batch_images))]
+ sequence_scores = None
+
+ attention_mask = torch.ones_like(batch_decoder_input, device=model.device)
+ all_done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)
+
+ # Decoder kv cache
+ # 7 (layers) x 2 (kv) x bs x 4 (heads) x max tokens x 64 (head dim)
+ dec_config = model.config.decoder
+ layer_count = dec_config.decoder_layers
+ kv_heads = dec_config.kv_heads
+ head_dim = int(dec_config.d_model / dec_config.decoder_attention_heads)
+ decoder_cache = torch.zeros((layer_count, 2, len(batch_images), kv_heads, settings.RECOGNITION_MAX_TOKENS, head_dim), dtype=model.dtype, device=model.device)
+ kv_mask = torch.zeros((len(batch_images), settings.RECOGNITION_MAX_TOKENS), device=model.device)
+
+ # Encoder kv cache
+ # 7 (layers) x 2 (kv) x bs x 4 (heads) x 196 (max tokens) x 64 (head dim)
+ encoder_cache = torch.zeros((layer_count, 2, len(batch_images), kv_heads, 196, head_dim), dtype=model.dtype, device=model.device)
+
with torch.inference_mode():
-     return_dict = model.generate(
-         pixel_values=batch_pixel_values,
-         decoder_input_ids=batch_decoder_input,
-         decoder_langs=batch_langs,
-         eos_token_id=processor.tokenizer.eos_id,
-         pad_token_id=processor.tokenizer.pad_token_id,
-         max_new_tokens=settings.RECOGNITION_MAX_TOKENS,
-         output_scores=True,
-         return_dict_in_generate=True
-     )
-     generated_ids = return_dict["sequences"]
-
-     # Find confidence scores
-     scores = return_dict["scores"] # Scores is a tuple, one per new sequence position. Each tuple element is bs x vocab_size
-     sequence_scores = torch.zeros(generated_ids.shape[0])
-     sequence_lens = torch.where(
-         generated_ids > processor.tokenizer.eos_id,
-         torch.ones_like(generated_ids),
-         torch.zeros_like(generated_ids)
-     ).sum(axis=-1).cpu()
-     prefix_len = generated_ids.shape[1] - len(scores) # Length of passed in tokens (bos, langs)
-     for token_idx, score in enumerate(scores):
-         probs = F.softmax(score, dim=-1)
-         max_probs = torch.max(probs, dim=-1).values
-         max_probs = torch.where(
-             generated_ids[:, token_idx + prefix_len] <= processor.tokenizer.eos_id,
-             torch.zeros_like(max_probs),
-             max_probs
-         ).cpu()
-         sequence_scores += max_probs
-     sequence_scores /= sequence_lens
-
- detected_text = processor.tokenizer.batch_decode(generated_ids)
+     while token_count < settings.RECOGNITION_MAX_TOKENS:
+         inference_token_count = batch_decoder_input.shape[-1]
+         return_dict = model(
+             decoder_input_ids=batch_decoder_input,
+             decoder_attention_mask=attention_mask,
+             decoder_kv_caches=None if token_count == 0 else [decoder_cache, encoder_cache],
+             decoder_past_token_count=token_count,
+             decoder_langs=batch_langs,
+             pixel_values=batch_pixel_values,
+             encoder_outputs=encoder_outputs,
+             return_dict=True,
+         )
+
+         logits = return_dict["logits"]
+         preds = torch.argmax(logits[:, -1], dim=-1)
+         scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
+         done = preds == processor.tokenizer.eos_id
+         all_done = all_done | done
+
+         if sequence_scores is None:
+             sequence_scores = scores
+         else:
+             scores[all_done == 1] = 0
+             sequence_scores = torch.cat([sequence_scores, scores], dim=1)
+
+         encoder_outputs = (return_dict["encoder_last_hidden_state"],)
+         past_key_values = return_dict["past_key_values"]
+         for layer_idx, layer in enumerate(past_key_values):
+             decoder_cache[layer_idx, 0, :, :, token_count:(token_count + inference_token_count), :] = layer[0]
+             decoder_cache[layer_idx, 1, :, :, token_count:(token_count + inference_token_count), :] = layer[1]
+
+             encoder_cache[layer_idx, 0, :, :, :, :] = layer[2]
+             encoder_cache[layer_idx, 1, :, :, :, :] = layer[3]
+
+         if all_done.all():
+             break
+
+         kv_mask[:, token_count:(token_count + inference_token_count)] = 1
+         attention_mask = torch.cat([kv_mask, ~all_done.unsqueeze(1)], dim=1)
+
+         for j, (pred, status) in enumerate(zip(preds, all_done)):
+             if not status:
+                 batch_predictions[j].append(int(pred))
+
+         batch_decoder_input = preds.unsqueeze(1)
+         token_count += inference_token_count
+
+ sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
+ detected_text = processor.tokenizer.batch_decode(batch_predictions)
detected_text = [truncate_repetitions(dt) for dt in detected_text]

# Postprocess to fix LaTeX output (add $$ signs, etc)
detected_text = [fix_math(text) if math and contains_math(text) else text for text, math in zip(detected_text, has_math)]
output_text.extend(detected_text)
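Taken together, the generate() call is replaced by an explicit greedy loop over preallocated caches. A heavily simplified, hypothetical sketch of the loop's shape (the model here is a stand-in that accepts the same keyword arguments as in the diff and returns logits plus per-layer key/value slices; the cross-attention cache and confidence scores are omitted for brevity):

import torch

def greedy_decode_static(model, decoder_input_ids, max_tokens, eos_id):
    # Shapes/attributes mirror the allocation above but are illustrative, not surya's API.
    bs = decoder_input_ids.shape[0]
    cfg = model.config.decoder
    head_dim = cfg.d_model // cfg.decoder_attention_heads
    cache = torch.zeros(cfg.decoder_layers, 2, bs, cfg.kv_heads, max_tokens, head_dim,
                        dtype=model.dtype, device=model.device)
    kv_mask = torch.zeros(bs, max_tokens, device=model.device)
    attention_mask = torch.ones_like(decoder_input_ids)
    all_done = torch.zeros(bs, dtype=torch.bool, device=model.device)
    predictions = [[] for _ in range(bs)]

    token_count = 0
    with torch.inference_mode():
        while token_count < max_tokens:
            step_len = decoder_input_ids.shape[-1]
            out = model(
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=attention_mask,
                decoder_kv_caches=None if token_count == 0 else cache,
                decoder_past_token_count=token_count,
            )
            preds = out["logits"][:, -1].argmax(dim=-1)
            all_done = all_done | (preds == eos_id)

            # Write this step's key/value slices into the fixed-size cache in place.
            for idx, layer in enumerate(out["past_key_values"]):
                cache[idx, 0, :, :, token_count:token_count + step_len] = layer[0]
                cache[idx, 1, :, :, token_count:token_count + step_len] = layer[1]

            if all_done.all():
                break

            # Grow the mask logically rather than physically: mark the slots just
            # filled, then append one column for the token fed in next step.
            kv_mask[:, token_count:token_count + step_len] = 1
            attention_mask = torch.cat([kv_mask, (~all_done).unsqueeze(1).to(kv_mask.dtype)], dim=1)

            for j, (pred, done) in enumerate(zip(preds, all_done)):
                if not done:
                    predictions[j].append(int(pred))

            decoder_input_ids = preds.unsqueeze(1)
            token_count += step_len

    return predictions

The design point is that every tensor the loop touches keeps a fixed shape across iterations, so nothing is reallocated as the sequence grows; the running attention mask and the done flags are what keep finished sequences from contributing further tokens.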