[Core] Logprobs support in Multi-step #7652

Merged Aug 30, 2024 (49 commits; changes shown from 1 commit)
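For context on the feature under review: with this PR, per-token logprobs can be requested while multi-step scheduling is active. A minimal sketch of that usage, assuming vLLM's num_scheduler_steps engine argument is what enables multi-step execution (the model name and step count are arbitrary):

from vllm import LLM, SamplingParams

# Multi-step scheduling: the engine runs several decode steps per
# scheduler invocation (num_scheduler_steps assumed here).
llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8)

# Ask for the top-5 logprobs of each sampled token.
params = SamplingParams(temperature=0.0, max_tokens=16, logprobs=5)

for out in llm.generate(["Hello, my name is"], params):
    completion = out.outputs[0]
    # completion.logprobs: one dict per generated token, mapping
    # token id -> Logprob(logprob, rank, decoded_token).
    print(completion.text)
    print(completion.logprobs[0])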
Commits (49, all by afeldman-nm):

f97e0ae  added example (Aug 21, 2024)
f969241  wip: (Aug 21, 2024)
642d31b  first working attempt at logprobs (Aug 21, 2024)
a0ca262  merge; format (Aug 21, 2024)
ed97288  passing test; dataclass (Aug 21, 2024)
861e1b9  refactoring (Aug 21, 2024)
8bc0765  Merge branch 'main' into logprobs_merge (Aug 21, 2024)
a34d1ac  refactoring (Aug 21, 2024)
4cda5c0  Merge branch 'logprobs' into logprobs_merge (Aug 21, 2024)
ac8a39a  Merge branch 'main' into logprobs_merge (Aug 21, 2024)
1284327  removing example (Aug 21, 2024)
a6c1207  removed example from build pipeline (Aug 21, 2024)
fe42995  fixed one docstring; embedded NUM_LOGPROBS (Aug 21, 2024)
9fb5bbe  test refactor (Aug 21, 2024)
046a8b1  incremental refactors (Aug 21, 2024)
fa86efd  remove unnecessary conftest change (Aug 21, 2024)
1c0ffb6  Update vllm/model_executor/layers/sampler.py (Aug 21, 2024)
3babadb  refactor (Aug 21, 2024)
f502029  Merge branch 'afeldman-nm/logprobs' of https://github.com/neuralmagic… (Aug 21, 2024)
1875b37  test_multi_step comment (Aug 21, 2024)
3760a95  utils function docstrings (Aug 21, 2024)
d43308c  docstring refactors (Aug 21, 2024)
54db498  merge (Aug 21, 2024)
dfbbaf0  passing tests & formatted (Aug 21, 2024)
5eebfca  Merge branch 'main' into logprobs_merge (Aug 21, 2024)
5e23d9a  Merge branch 'main' into logprobs_merge (Aug 22, 2024)
717efa3  merge; format (Aug 22, 2024)
e0d59ce  removed incorrect SamplerOutput imports (Aug 22, 2024)
102fd92  formatting (Aug 22, 2024)
948f4ef  Update tests/multi_step/test_correctness.py (Aug 22, 2024)
6e6711f  fixed comment (Aug 22, 2024)
f61163e  merge; format (Aug 23, 2024)
1cc93dd  rename (Aug 23, 2024)
4995204  Merge branch 'logprobs' into logprobs_merge (Aug 23, 2024)
da5826b  test modification (Aug 26, 2024)
d4fb430  merge; format (Aug 26, 2024)
b6752e0  merge (Aug 27, 2024)
1e42656  formatting (Aug 27, 2024)
cd0fdf9  disabled logprobs pythonization when logprobs are disabled (Aug 27, 2024)
3fecbc4  wip (Aug 27, 2024)
67bd035  skip logprobs processing entirely when logprobs are not enabled; form… (Aug 27, 2024)
419659d  multi-step output processing; formatting (Aug 27, 2024)
55eaab9  wip (Aug 27, 2024)
bae1fb9  small fixes (Aug 27, 2024)
fbb75b7  reverting to no prompt-logprobs support; merged in main (Aug 28, 2024)
63c5582  timeout increase (Aug 28, 2024)
8191571  refactoring (Aug 28, 2024)
9a708f8  Merge branch 'main' into logprobs_no_prompt (Aug 28, 2024)
e54606d  upstream merge (Aug 29, 2024)
disabled logprobs pythonization when logprobs are disabled
afeldman-nm committed Aug 27, 2024
commit cd0fdf9e547e3a41076257c3f1d62ca1e7a3238a
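The core of this commit: deferred logprob pythonization is now gated on whether any sequence group in the batch actually requested logprobs, rather than running whenever sampler CPU output is skipped. A standalone sketch of that predicate, lifted from the diff below (the free-standing helper is illustrative; the sampling_params fields are vLLM's):

def logprobs_are_requested(seq_groups) -> bool:
    # Pythonizing logprob tensors costs CPU time; it is only worthwhile
    # when some sequence group asked for sample or prompt logprobs.
    return any(
        sg.sampling_params.logprobs is not None
        or sg.sampling_params.prompt_logprobs is not None
        for sg in seq_groups)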
vllm/worker/multi_step_model_runner.py: 42 changes (31 additions & 11 deletions)

--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -18,7 +18,7 @@
                                        SamplingMetadata, get_logprobs,
                                        get_pythonized_sample_results)
 from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
-                           SequenceGroupMetadata, SequenceOutput)
+                           SequenceGroupMetadata, SequenceOutput, Logprob)
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
@@ -525,17 +525,24 @@ def _pythonize_sampler_output(
     #
     # However this computation may be skipped entirely
     # if no pythonization was deferred.
+    seq_groups = sampling_metadata.seq_groups
+    logprobs_are_requested = any([
+        sg.sampling_params.logprobs is not None
+        or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
+    ])
+    do_pythonize_logprobs = (skip_sampler_cpu_output
+                             and logprobs_are_requested)
     (
         prompt_logprobs,
         sample_logprobs,
     ) = (deferred_pythonize_logprobs(output, sampling_metadata,
                                      logprobs_tensor)
-         if skip_sampler_cpu_output else (None, None))
+         if do_pythonize_logprobs else (None, None))
 
-    for sgdx, (seq_group, sample_result) in enumerate(
-            zip(sampling_metadata.seq_groups, samples_list)):
+    for sgdx, (seq_group,
+               sample_result) in enumerate(zip(seq_groups, samples_list)):
 
-        if skip_sampler_cpu_output:
+        if do_pythonize_logprobs:
             assert prompt_logprobs is not None
             assert sample_logprobs is not None
 
@@ -545,7 +552,12 @@ def _pythonize_sampler_output(
             ) = (  # Utilize deferred pythonization results
                 prompt_logprobs[sgdx],
                 sample_logprobs[sgdx],
-            ) if skip_sampler_cpu_output else (
+            )
+        elif logprobs_are_requested:
+            (
+                group_prompt_logprobs,
+                group_sample_logprobs,
+            ) = (
                 # profile_run: use already-computed logprobs
                 output.outputs[sgdx].prompt_logprobs,
                 [sample.logprobs for sample in output.outputs[sgdx].samples])
@@ -557,11 +569,19 @@ def _pythonize_sampler_output(
         if seq_group.sampling_params.logits_processors:
             assert len(seq_group.sampling_params.logits_processors) == 0, (
                 "Logits Processors are not supported in multi-step decoding")
-        for (parent_id, next_token_id,
-             logprobs) in zip(parent_ids, next_token_ids,
-                              group_sample_logprobs):
+        for tdx, (parent_id,
+                  next_token_id) in enumerate(zip(parent_ids, next_token_ids)):
             seq_outputs.append(
-                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
+                SequenceOutput(seq_ids[parent_id], next_token_id,
+                               (group_sample_logprobs[tdx]
+                                if logprobs_are_requested else {
+                                    next_token_id:
+                                    Logprob(logprob=float('inf'),
+                                            rank=None,
+                                            decoded_token=None)
+                                })))
         output.outputs.append(
-            CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
+            CompletionSequenceGroupOutput(
+                seq_outputs,
+                (group_prompt_logprobs if logprobs_are_requested else None)))
     assert len(output.outputs) > 0
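When no logprobs were requested, each SequenceOutput still carries a logprobs dict, so the new code attaches a sentinel entry for the sampled token instead of computing anything. A minimal sketch of that fallback, mirroring the diff above (the helper function is illustrative; Logprob and SequenceOutput are vLLM's):

from vllm.sequence import Logprob, SequenceOutput

def build_seq_output(seq_id: int, next_token_id: int,
                     sample_logprobs, logprobs_are_requested: bool):
    # With logprobs disabled, attach a sentinel Logprob (inf logprob,
    # no rank, no decoded token) so downstream consumers that expect a
    # non-empty logprobs dict keep working without extra computation.
    if logprobs_are_requested:
        logprobs = sample_logprobs
    else:
        logprobs = {
            next_token_id: Logprob(logprob=float('inf'),
                                   rank=None,
                                   decoded_token=None)
        }
    return SequenceOutput(seq_id, next_token_id, logprobs)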