Skip to content

Commit

Permalink
style(pytorch_poc): format
Browse files Browse the repository at this point in the history
  • Loading branch information
tpoisonooo committed Nov 3, 2023
1 parent ae19e01 commit 1cd516a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
5 changes: 2 additions & 3 deletions lmdeploy/pytorch_poc/kernels/rerope_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _rerope_fwd_kernel(
qk += tl.dot(q2, k2, out_dtype=tl.float32)
elif start_n > (
start_m + 1
) * BLOCK_M - WINDOW and start_n < start_m * BLOCK_M + WINDOW - BLOCK_N:
) * BLOCK_M - WINDOW and start_n < start_m * BLOCK_M + WINDOW - BLOCK_N: # noqa: E501
k1 = tl.load(K1_block_ptr)
v = tl.load(V_block_ptr)
qk += tl.dot(q1, k1, out_dtype=tl.float32)
Expand Down Expand Up @@ -328,9 +328,8 @@ def f(fn, q1, q2, k1, k2, v, sm_scale, window):
import time
begin = time.time()
LOOP = 100
for i in range(LOOP):
for _ in range(LOOP):
rerope_attention_fwd(q1, q2, k1, k2, v, True, sm_scale, WINDOW)
timecost = (time.time() - begin) / LOOP
print(time.time() - begin)


Expand Down
11 changes: 6 additions & 5 deletions lmdeploy/pytorch_poc/passkey_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from lmdeploy.tokenizer import Tokenizer

os.environ['TM_LOG_LEVEL'] = 'ERROR'
import pdb


class LLM(object):
Expand Down Expand Up @@ -85,7 +84,7 @@ def parse_config():
return args


# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py # noqa: E501
def generate_prompt_landmark(n_garbage=60000, seed=666):
"""Generates a text file and inserts an passkey at a random position."""
from numpy import random as nprandom
Expand All @@ -94,14 +93,14 @@ def generate_prompt_landmark(n_garbage=60000, seed=666):
n_garbage_prefix = nprandom.randint(0, n_garbage)
n_garbage_suffix = n_garbage - n_garbage_prefix

task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.'
garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.'
task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.' # noqa: E501
garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.' # noqa: E501
garbage_inf = ' '.join([garbage] * 5000)
assert len(garbage_inf) >= n_garbage
garbage_prefix = garbage_inf[:n_garbage_prefix]
garbage_suffix = garbage_inf[:n_garbage_suffix]
pass_key = nprandom.randint(1, 50000)
information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.'
information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.' # noqa: E501
final_question = 'What is the pass key? The pass key is'
lines = [
task_description,
Expand All @@ -111,6 +110,7 @@ def generate_prompt_landmark(n_garbage=60000, seed=666):
final_question,
]
nprandom.set_state(rnd_state)

return '\n'.join(lines), str(pass_key)


Expand Down Expand Up @@ -145,6 +145,7 @@ def main(args):
accuracy = passed_tests / args.num_tests
print('accuracy on the token length %d is %f' % (avg_tokens, accuracy))
all_accuries[str(avg_tokens)] = accuracy

print('accuries over tokens', all_accuries)


Expand Down

0 comments on commit 1cd516a

Please sign in to comment.