Beam search seems not to work as expected #200

Closed
goodbai-nlp opened this issue Jun 22, 2023 · 5 comments · Fixed by #202
Labels
bug Something isn't working

Comments

@goodbai-nlp

Hi,

I am trying to use vLLM's beam search with the following code, but it processes and generates many more results than there are input prompts. Could anyone tell me how to fix this?

from vllm import LLM, SamplingParams

# `args` comes from the reporter's own argparse setup (not shown in the issue)
model = LLM(model=args.model_name_or_path)
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Beam search with best_of=5 beams; n=1 means one returned sequence per prompt
gen_params = SamplingParams(n=1, use_beam_search=True, best_of=5, temperature=0, top_p=1, top_k=-1, max_tokens=args.max_new_tokens, stop=["</s>"])
outputs = model.generate(prompts, gen_params)
print(f"generate {len(outputs)} outputs!!!")
[Screenshot of the script's output]
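To see why the number of outputs should not depend on `max_tokens`, here is a simplified beam search in plain Python. It is an illustrative sketch, not vLLM's implementation: the three-token vocabulary and the `toy_next_logprobs` scorer are hypothetical, there is no length penalty, and `beam_width` plays the role that `best_of` had when `use_beam_search=True`, while `n` is the number of sequences returned per prompt. However many beams are explored or tokens are allowed, `generate` returns exactly one entry per prompt.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Tokens = Tuple[str, ...]
# Hypothetical stand-in for a language model: maps a token sequence to
# log-probabilities of every possible next token.
NextLogprobs = Callable[[Tokens], Dict[str, float]]


def beam_search(prompt: Tokens, next_logprobs: NextLogprobs, beam_width: int,
                max_tokens: int, n: int = 1, eos: str = "</s>") -> List[Tuple[Tokens, float]]:
    """Return up to n (continuation, cumulative log-prob) pairs for ONE prompt."""
    if n > beam_width:
        raise ValueError("n must not exceed beam_width")
    beams: List[Tuple[Tokens, float]] = [((), 0.0)]  # live hypotheses
    finished: List[Tuple[Tokens, float]] = []        # hypotheses that emitted eos
    for _ in range(max_tokens):
        # Expand every live beam by every possible next token.
        candidates = [(toks + (tok,), score + lp)
                      for toks, score in beams
                      for tok, lp in next_logprobs(prompt + toks).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        # Keep the best beam_width live hypotheses; set aside any that ended.
        beams = []
        for toks, score in candidates:
            if len(beams) == beam_width:
                break
            if toks[-1] == eos:
                finished.append((toks, score))
            else:
                beams.append((toks, score))
        finished.sort(key=lambda c: c[1], reverse=True)
        # Log-probs are <= 0, so live scores can only fall: stop once n finished
        # hypotheses already score at least as well as the best live one.
        if not beams or (len(finished) >= n and finished[n - 1][1] >= beams[0][1]):
            break
    # Beams still alive when max_tokens runs out count as (truncated) results.
    results = sorted(finished + beams, key=lambda c: c[1], reverse=True)
    return results[:n]


def generate(prompts: Sequence[Tokens], next_logprobs: NextLogprobs, beam_width: int,
             max_tokens: int, n: int = 1) -> List[List[Tuple[Tokens, float]]]:
    # One entry per prompt: beam_width and max_tokens change the search, not the count.
    return [beam_search(p, next_logprobs, beam_width, max_tokens, n) for p in prompts]


def toy_next_logprobs(seq: Tokens) -> Dict[str, float]:
    # Hypothetical 3-token vocabulary in which "</s>" grows likelier with length.
    p_end = min(0.9, 0.1 * (len(seq) + 1))
    p_other = (1.0 - p_end) / 2
    return {"a": math.log(p_other), "b": math.log(p_other), "</s>": math.log(p_end)}


prompts = [tuple(p.split()) for p in ("Hello, my name is",
                                      "The president of the United States is",
                                      "The capital of France is",
                                      "The future of AI is")]
for max_tokens in (64, 256):
    results = generate(prompts, toy_next_logprobs, beam_width=5, max_tokens=max_tokens, n=1)
    print(f"max_tokens={max_tokens}: {len(results)} outputs for {len(prompts)} prompts")
```

The early-stopping test is only valid because the sketch scores hypotheses by raw cumulative log-probability; a real engine that applies a length penalty needs a different stopping rule.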
@WoosukKwon
Collaborator

Hi @goodbai-nlp, thanks for trying out vLLM! According to your output, it seems you passed 173 prompts as input to LLM. Could you check again?

@goodbai-nlp
Author

@WoosukKwon Yeah, that's weird. I am sure only 4 prompts are passed as input (I printed the number of requests to confirm). Could you try to reproduce this bug?

@WoosukKwon
Collaborator

WoosukKwon commented Jun 22, 2023

@goodbai-nlp Could you try running https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py and let me know the output?

When I ran the attached code, I got 4 outputs.

@goodbai-nlp
Author

@WoosukKwon, I got 4 outputs from the attached code. However, when I try to use beam search and set a larger value for max_tokens, the bug occurs.
[Screenshot]
The output when max_tokens=64:
[Screenshot]
The output when max_tokens=256:
[Screenshot]
In my opinion, max_tokens should only affect the maximum number of generated tokens, not the number of outputs.
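The reporter's expectation can be turned into a quick regression check. This sketch assumes each result returned by `LLM.generate` exposes `.prompt` and `.outputs` (as vLLM's `RequestOutput` does) and that results come back in prompt order; `check_one_output_per_prompt` is a hypothetical helper name, and the `SimpleNamespace` objects only stand in for real results so the snippet runs without vLLM or a GPU.

```python
from types import SimpleNamespace
from typing import Any, Sequence


def check_one_output_per_prompt(prompts: Sequence[str], outputs: Sequence[Any], n: int = 1) -> None:
    """Raise ValueError unless there is exactly one result per prompt, each with n sequences."""
    if len(outputs) != len(prompts):
        raise ValueError(f"expected {len(prompts)} outputs, got {len(outputs)}")
    for prompt, out in zip(prompts, outputs):
        # Assumes results are in prompt order and expose .prompt / .outputs.
        if out.prompt != prompt:
            raise ValueError(f"unexpected result for prompt {prompt!r}: {out.prompt!r}")
        if len(out.outputs) != n:
            raise ValueError(f"expected {n} sequences for {prompt!r}, got {len(out.outputs)}")


prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Stand-ins for illustration only; with vLLM these would come from model.generate().
fake_outputs = [SimpleNamespace(prompt=p, outputs=[SimpleNamespace(text="...")]) for p in prompts]
check_one_output_per_prompt(prompts, fake_outputs, n=1)

# With the vLLM version used in this issue (not executed here), repeat per max_tokens:
# for max_tokens in (64, 256):
#     params = SamplingParams(n=1, use_beam_search=True, best_of=5, temperature=0,
#                             max_tokens=max_tokens)
#     check_one_output_per_prompt(prompts, model.generate(prompts, params))
```

Raising `ValueError` rather than using `assert` keeps the check active even when Python runs with `-O`.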

@WoosukKwon WoosukKwon added the bug Something isn't working label Jun 22, 2023
@WoosukKwon
Collaborator

@goodbai-nlp I reproduced the bug and fixed it in #202. Thanks for the bug report!

jikunshang pushed a commit to jikunshang/vllm that referenced this issue Sep 24, 2024
Add Dockerfile.hpu

FIX HabanaAI#199
