Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Two Step Edits #530

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Prev Previous commit
Next Next commit
fix
  • Loading branch information
biobootloader committed Mar 6, 2024
commit adcb7e9538d9f9610d4ac6eaee788af9bd03e164
2 changes: 1 addition & 1 deletion mentat/code_file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def write_changes_to_files_two_step(
applied_edits: list[tuple[str, str]] = []
for abs_path, new_file_str in rewritten_files:
new_lines = new_file_str.splitlines()
self.write_to_file(abs_path, new_lines)
self.write_to_file(Path(abs_path), new_lines)
applied_edits.append((abs_path, new_file_str))

return applied_edits
Expand Down
3 changes: 2 additions & 1 deletion mentat/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class ParsedLLMResponse:
full_response: str = attr.field()
conversation: str = attr.field()
file_edits: list[FileEdit] = attr.field()
rewritten_files: list[tuple[str, str]] = attr.field(default=[])
rewritten_files: list[tuple[str, str]] = attr.field(factory=list)
interrupted: bool = attr.field(default=False)


Expand Down Expand Up @@ -338,6 +338,7 @@ async def stream_and_parse_llm_response(
message,
conversation,
[file_edit for file_edit in file_edits.values()],
[],
interrupted,
)

Expand Down
62 changes: 36 additions & 26 deletions mentat/stream_model_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from difflib import ndiff
from json import JSONDecodeError
from pathlib import Path
from typing import AsyncIterator
from typing import Any, AsyncIterator

from openai.types.chat import (
ChatCompletionChunk,
Expand Down Expand Up @@ -104,7 +104,7 @@ def get_two_step_rewrite_file_prompt() -> str:
return read_prompt(two_step_edit_prompt_rewrite_file_filename)


def print_colored_diff(str1, str2, stream):
def print_colored_diff(str1: str, str2: str, stream: Any):
diff = ndiff(str1.splitlines(), str2.splitlines())

for line in diff:
Expand Down Expand Up @@ -185,19 +185,25 @@ async def stream_model_response_two_step(
response_format=ResponseFormat(type="json_object"),
)

try:
response_json = json.loads(list_files_response.choices[0].message.content)
except JSONDecodeError:
stream.send("Error processing model response: Invalid JSON", style="error")
# TODO: handle error
response_json = {}
if list_files_response.choices and list_files_response.choices[0].message.content:
try:
response_json: dict[str, list[str]] = json.loads(
list_files_response.choices[0].message.content
)
except JSONDecodeError:
stream.send("Error processing model response: Invalid JSON", style="error")
# TODO: handle error

stream.send(f"\n{response_json}\n")

# TODO remove line numbers when running two step edit
# TODO handle creating new files - including update prompt to know that's possible

rewritten_files = []
for file_path in response_json["files"]:
rewritten_files: list[tuple[str, str]] = []

file_paths: list[str] = response_json.get("files", [])
for file_path in file_paths:
full_path = (cwd / Path(file_path)).resolve()
code_file_lines = code_file_manager.file_lines.get(full_path, [])
code_file_string = "\n".join(code_file_lines)
Expand All @@ -222,21 +228,26 @@ async def stream_model_response_two_step(
model="gpt-3.5-turbo-0125", # TODO add config for secondary model
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's important to handle potential assertion failures gracefully in production code. Consider adding error handling for cases where the expected rewrite_file_response format does not match the assumptions (e.g., missing ``` markers). This could include logging a warning and skipping the file or providing a clear error message to the user.

stream=False,
)
rewrite_file_response = rewrite_file_response.choices[0].message.content
lines = rewrite_file_response.splitlines()
# TODO remove asserts
assert "```" in lines[0]
assert "```" in lines[-1]
lines = lines[1:-1]
rewrite_file_response = "\n".join(lines)

rewritten_files.append((full_path, rewrite_file_response))

stream.send(f"\n### File Rewrite Response: {file_path} ###\n")
# stream.send(rewrite_file_response)

# TODO stream colored diff, skipping unchanged lines (except some for context)
print_colored_diff(code_file_string, rewrite_file_response, stream)
if (
rewrite_file_response
and rewrite_file_response.choices
and rewrite_file_response.choices[0].message.content
):
rewrite_file_response = rewrite_file_response.choices[0].message.content
lines = rewrite_file_response.splitlines()
# TODO remove asserts
assert "```" in lines[0]
assert "```" in lines[-1]
lines = lines[1:-1]
rewrite_file_response = "\n".join(lines)

rewritten_files.append((str(full_path), rewrite_file_response))

stream.send(f"\n### File Rewrite Response: {file_path} ###\n")
# stream.send(rewrite_file_response)

# TODO stream colored diff, skipping unchanged lines (except some for context)
print_colored_diff(code_file_string, rewrite_file_response, stream)

# async with parser.interrupt_catcher():
# parsed_llm_response = await parser.stream_and_parse_llm_response(
Expand Down Expand Up @@ -297,8 +308,7 @@ async def stream_and_parse_llm_response_two_step(

# Only finish printing if we don't quit from ctrl-c
printer.wrap_it_up()
if printer_task is not None:
await printer_task
await printer_task

logging.debug("LLM Response:")
logging.debug(message)
Expand Down
Loading