Two Step Edits #530

Draft
wants to merge 17 commits into base: main
two step stream_and_parse, initial version
biobootloader committed Feb 23, 2024
commit 799911abe3e062c06c7c5675b80bcc404aee7949
16 changes: 13 additions & 3 deletions mentat/conversation.py
@@ -19,7 +19,11 @@
from mentat.parsers.file_edit import FileEdit
from mentat.parsers.parser import ParsedLLMResponse
from mentat.session_context import SESSION_CONTEXT
from mentat.stream_model_response import stream_model_response
from mentat.stream_model_response import (
    get_two_step_system_prompt,
    stream_model_response,
    stream_model_response_two_step,
)
from mentat.transcripts import ModelMessage, TranscriptMessage, UserMessage


@@ -166,7 +170,10 @@ def get_messages(
return _messages
else:
parser = config.parser
prompt = parser.get_system_prompt()
if session_context.config.two_step_edits:
    prompt = get_two_step_system_prompt()
else:
    prompt = parser.get_system_prompt()
prompt_message: ChatCompletionMessageParam = (
ChatCompletionSystemMessageParam(
role="system",
@@ -208,7 +215,10 @@ async def get_model_response(self) -> ParsedLLMResponse:
)

try:
response = await stream_model_response(messages_snapshot)
if session_context.config.two_step_edits:
    response = await stream_model_response_two_step(messages_snapshot)
else:
    response = await stream_model_response(messages_snapshot)
except RateLimitError:
stream.send(
"Rate limit error received from OpenAI's servers using model"
7 changes: 4 additions & 3 deletions mentat/resources/prompts/two_step_edit_prompt.txt
@@ -7,6 +7,7 @@ In your response:
- don't rewrite whole files - just let the user know which edits to make
- make it clear where the edits should be, i.e. say you are rewriting a function or inserting it below another function
- unless the user asks for things to be commented out, don't suggest that. Just suggest deleting/removing unwanted code
- when specifying code to remove, try not to rewrite much code verbatim. Instead, suggest removing functions by name, using ellipses to indicate lines to remove, etc.
- try not to rewrite long sections of code that are mostly unchanged. Instead attempt to just write the changed lines, if possible to do so clearly.
- once you've specified the required edits, do not reiterate them and do not show how the code would look with them. Once is enough, our users value both brevity and clarity.
- when specifying code to remove, try not to rewrite much code verbatim. Instead, suggest removing functions by name, using ellipses to indicate lines to remove, etc
- try not to rewrite long sections of code that are mostly unchanged. Instead attempt to just write the changed lines, if possible to do so clearly
- our users value brevity. Once you've specified edits in enough detail that an intelligent actor could easily make them, do not elaborate or reiterate how they would look in the code
- after listing the edits, do not say something like "the code or file should now look like..."
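For illustration only (not part of this prompt file or the diff): under these rules, a first-pass description of edits might read roughly like: "In utils.py, rename the helper load_config_file to load_config and delete the unused load_legacy_config function; in cli.py, update its one call site to the new name." The file and function names here are hypothetical.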
147 changes: 136 additions & 11 deletions mentat/stream_model_response.py
@@ -1,12 +1,21 @@
from __future__ import annotations

from openai.types.chat import ChatCompletionMessageParam
import asyncio
import logging
from pathlib import Path
from typing import AsyncIterator

from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam

from mentat.llm_api_handler import TOKEN_COUNT_WARNING, count_tokens, prompt_tokens
from mentat.parsers.parser import ParsedLLMResponse
from mentat.parsers.streaming_printer import FormattedString, StreamingPrinter
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import add_newline

two_step_edit_prompt_filename = Path("two_step_edit_prompt.txt")


async def stream_model_response(
messages: list[ChatCompletionMessageParam],
@@ -19,6 +28,16 @@ async def stream_model_response(
llm_api_handler = session_context.llm_api_handler
cost_tracker = session_context.cost_tracker

num_prompt_tokens = prompt_tokens(messages, config.model)
stream.send(f"Total token count: {num_prompt_tokens}", style="info")
if num_prompt_tokens > TOKEN_COUNT_WARNING:
stream.send(
"Warning: LLM performance drops off rapidly at large context sizes. Use"
" /clear to clear context or use /exclude to exclude any uneccessary"
" files.",
style="warning",
)

stream.send(
None,
channel="loading",
@@ -35,16 +54,6 @@ async def stream_model_response(
terminate=True,
)

num_prompt_tokens = prompt_tokens(messages, config.model)
stream.send(f"Total token count: {num_prompt_tokens}", style="info")
if num_prompt_tokens > TOKEN_COUNT_WARNING:
stream.send(
"Warning: LLM performance drops off rapidly at large context sizes. Use"
" /clear to clear context or use /exclude to exclude any uneccessary"
" files.",
style="warning",
)

stream.send("Streaming... use control-c to interrupt the model at any point\n")
async with parser.interrupt_catcher():
parsed_llm_response = await parser.stream_and_parse_llm_response(
@@ -69,3 +78,119 @@ async def stream_model_response(
)

return parsed_llm_response


def get_two_step_system_prompt() -> str:
return read_prompt(two_step_edit_prompt_filename)


async def stream_model_response_two_step(
Review comment (Contributor):
The stream_model_response_two_step function contains several TODO comments that suggest incomplete implementation, such as handling interrupts, tracking costs of all calls, and logging API calls. It's crucial to address these TODOs before merging this feature into the main branch to ensure the two-step edit process is fully functional and robust. (A rough sketch of how these TODOs might be wired appears after this file's diff.)

messages: list[ChatCompletionMessageParam],
) -> ParsedLLMResponse:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
code_file_manager = session_context.code_file_manager
config = session_context.config
parser = config.parser
llm_api_handler = session_context.llm_api_handler
cost_tracker = session_context.cost_tracker

num_prompt_tokens = prompt_tokens(messages, config.model)
stream.send(f"Total token count: {num_prompt_tokens}", style="info")
if num_prompt_tokens > TOKEN_COUNT_WARNING:
stream.send(
"Warning: LLM performance drops off rapidly at large context sizes. Use"
" /clear to clear context or use /exclude to exclude any uneccessary"
" files.",
style="warning",
)

stream.send(
None,
channel="loading",
)
response = await llm_api_handler.call_llm_api(
messages,
config.model,
stream=True,
response_format=parser.response_format(),
)
stream.send(
None,
channel="loading",
terminate=True,
)

# TODO: identify files mentioned, rewrite them with new calls
# TODO: add interrupt ability
# TODO: make sure to track costs of all calls and log api calls
stream.send("Streaming... use control-c to interrupt the model at any point\n")
first_message = await stream_and_parse_llm_response_two_step(response)

# async with parser.interrupt_catcher():
# parsed_llm_response = await parser.stream_and_parse_llm_response(
# add_newline(response)
# )

# # Sampler and History require previous_file_lines
# for file_edit in parsed_llm_response.file_edits:
# file_edit.previous_file_lines = code_file_manager.file_lines.get(
# file_edit.file_path, []
# )
# if not parsed_llm_response.interrupted:
# cost_tracker.display_last_api_call()
# else:
# # Generator doesn't log the api call if we interrupt it
# cost_tracker.log_api_call_stats(
# num_prompt_tokens,
# count_tokens(
# parsed_llm_response.full_response, config.model, full_message=False
# ),
# config.model,
# display=True,
# )

return ParsedLLMResponse(
full_response=first_message,
conversation=first_message,
# [file_edit for file_edit in file_edits.values()],
file_edits=[],
interrupted=False,
)


async def stream_and_parse_llm_response_two_step(
response: AsyncIterator[ChatCompletionChunk],
) -> str:
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
printer = StreamingPrinter()
printer_task = asyncio.create_task(printer.print_lines())

message = ""

async for chunk in response:
# if self.shutdown.is_set():
# interrupted = True
# printer.shutdown_printer()
# if printer_task is not None:
# await printer_task
# stream.send("\n\nInterrupted by user. Using the response up to this point.")
# break

content = ""
if len(chunk.choices) > 0:
content = chunk.choices[0].delta.content or ""

message += content
printer.add_string(content, end="")

# Only finish printing if we don't quit from ctrl-c
printer.wrap_it_up()
if printer_task is not None:
await printer_task

logging.debug("LLM Response:")
logging.debug(message)

return message
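
For reference only, and not part of commit 799911a: below is a minimal sketch of how the interrupt and cost-tracking TODOs flagged in the review comment above might be wired inside stream_model_response_two_step, reusing only helpers already imported in this file (parser.interrupt_catcher, count_tokens, cost_tracker.log_api_call_stats). Whether the parser's interrupt scope is the right signal for the two-step printer loop is an assumption, not something this commit establishes.

# Hypothetical rework of the streaming section above; assumes the parser's
# interrupt_catcher() scope also applies cleanly to the two-step flow.
stream.send("Streaming... use control-c to interrupt the model at any point\n")
async with parser.interrupt_catcher():
    first_message = await stream_and_parse_llm_response_two_step(response)

# Log the first call's cost explicitly, mirroring the interrupted branch of
# stream_model_response; the later per-file rewrite calls would need the
# same treatment once they exist.
cost_tracker.log_api_call_stats(
    num_prompt_tokens,
    count_tokens(first_message, config.model, full_message=False),
    config.model,
    display=True,
)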