diff --git a/mentat/code_edit_feedback.py b/mentat/code_edit_feedback.py index 6c33d5777..1f53a5b3a 100644 --- a/mentat/code_edit_feedback.py +++ b/mentat/code_edit_feedback.py @@ -66,6 +66,50 @@ async def get_user_feedback_on_edits( return edits_to_apply, need_user_request +async def get_user_feedback_on_edits_two_step( + rewritten_files: list[tuple[str, str]], +) -> tuple[list[tuple[str, str]], bool]: + session_context = SESSION_CONTEXT.get() + stream = session_context.stream + conversation = session_context.conversation + + stream.send( + "Apply these changes? 'Y/n' or provide feedback.", + style="input", + ) + user_response_message = await collect_user_input() + user_response = user_response_message.data + + need_user_request = True + match user_response.lower(): + case "y" | "": + rewritten_files_to_apply = rewritten_files + conversation.add_message( + ChatCompletionSystemMessageParam( + role="system", content="User chose to apply all your changes." + ) + ) + case "n": + rewritten_files_to_apply = [] + conversation.add_message( + ChatCompletionSystemMessageParam( + role="system", + content="User chose not to apply any of your changes.", + ) + ) + case _: + need_user_request = False + rewritten_files_to_apply = [] + conversation.add_message( + ChatCompletionSystemMessageParam( + role="system", + content="User chose not to apply any of your changes.", + ) + ) + conversation.add_user_message(user_response) + return rewritten_files_to_apply, need_user_request + + async def _user_filter_changes(file_edits: list[FileEdit]) -> list[FileEdit]: new_edits = list[FileEdit]() for file_edit in file_edits: diff --git a/mentat/code_file_manager.py b/mentat/code_file_manager.py index e8ddef059..282812a85 100644 --- a/mentat/code_file_manager.py +++ b/mentat/code_file_manager.py @@ -149,6 +149,18 @@ async def write_changes_to_files( self.history.push_edits() return applied_edits + # TODO handle creation, deletion, rename, undo/redo, check if file was modified, etc. 
+ async def write_changes_to_files_two_step( + self, rewritten_files: list[tuple[str, str]] + ) -> list[tuple[str, str]]: + applied_edits: list[tuple[str, str]] = [] + for abs_path, new_file_str in rewritten_files: + new_lines = new_file_str.splitlines() + self.write_to_file(Path(abs_path), new_lines) + applied_edits.append((abs_path, new_file_str)) + + return applied_edits + def get_file_checksum(self, path: Path, interval: Interval | None = None) -> str: if path.is_dir(): return "" # TODO: Build and maintain a hash tree for git_root diff --git a/mentat/config.py b/mentat/config.py index 4dbc19bdd..b21581f84 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -98,6 +98,16 @@ class Config: }, converter=converters.optional(converters.to_bool), ) + two_step_edits: bool = attr.field( + default=False, + metadata={ + "description": ( + "Experimental feature that uses multiple LLM calls to make and parse" + " edits" + ), + "auto_completions": bool_autocomplete, + }, + ) revisor: bool = attr.field( default=False, metadata={ diff --git a/mentat/conversation.py b/mentat/conversation.py index f551cc927..f4a786b4a 100644 --- a/mentat/conversation.py +++ b/mentat/conversation.py @@ -16,7 +16,6 @@ ) from mentat.llm_api_handler import ( - TOKEN_COUNT_WARNING, count_tokens, get_max_tokens, prompt_tokens, @@ -25,8 +24,12 @@ from mentat.parsers.file_edit import FileEdit from mentat.parsers.parser import ParsedLLMResponse from mentat.session_context import SESSION_CONTEXT +from mentat.stream_model_response import ( + get_two_step_system_prompt, + stream_model_response, + stream_model_response_two_step, +) from mentat.transcripts import ModelMessage, TranscriptMessage, UserMessage -from mentat.utils import add_newline class MentatAssistantMessageParam(ChatCompletionAssistantMessageParam): @@ -181,13 +184,21 @@ async def get_messages( if ctx.config.no_parser_prompt: system_prompt = [] else: - parser = ctx.config.parser - system_prompt = [ - ChatCompletionSystemMessageParam( - role="system", - content=parser.get_system_prompt(), - ) - ] + if ctx.config.two_step_edits: + system_prompt = [ + ChatCompletionSystemMessageParam( + role="system", + content=get_two_step_system_prompt(), + ) + ] + else: + parser = ctx.config.parser + system_prompt = [ + ChatCompletionSystemMessageParam( + role="system", + content=parser.get_system_prompt(), + ) + ] return system_prompt + _messages @@ -195,78 +206,6 @@ def clear_messages(self) -> None: """Clears the messages in the conversation""" self._messages = list[ChatCompletionMessageParam]() - async def _stream_model_response( - self, - messages: list[ChatCompletionMessageParam], - ) -> ParsedLLMResponse: - session_context = SESSION_CONTEXT.get() - stream = session_context.stream - code_file_manager = session_context.code_file_manager - config = session_context.config - parser = config.parser - llm_api_handler = session_context.llm_api_handler - cost_tracker = session_context.cost_tracker - - stream.send( - None, - channel="loading", - ) - response = await llm_api_handler.call_llm_api( - messages, - config.model, - stream=True, - response_format=parser.response_format(), - ) - stream.send( - None, - channel="loading", - terminate=True, - ) - - num_prompt_tokens = prompt_tokens(messages, config.model) - stream.send(f"Total token count: {num_prompt_tokens}", style="info") - if num_prompt_tokens > TOKEN_COUNT_WARNING: - stream.send( - "Warning: LLM performance drops off rapidly at large context sizes. 
Use" - " /clear to clear context or use /exclude to exclude any uneccessary" - " files.", - style="warning", - ) - - stream.send("Streaming... use control-c to interrupt the model at any point\n") - async with parser.interrupt_catcher(): - parsed_llm_response = await parser.stream_and_parse_llm_response( - add_newline(response) - ) - # Sampler and History require previous_file_lines - for file_edit in parsed_llm_response.file_edits: - file_edit.previous_file_lines = code_file_manager.file_lines.get( - file_edit.file_path, [] - ) - if not parsed_llm_response.interrupted: - cost_tracker.display_last_api_call() - else: - # Generator doesn't log the api call if we interrupt it - cost_tracker.log_api_call_stats( - num_prompt_tokens, - count_tokens( - parsed_llm_response.full_response, config.model, full_message=False - ), - config.model, - display=True, - ) - - messages.append( - ChatCompletionAssistantMessageParam( - role="assistant", content=parsed_llm_response.full_response - ) - ) - self.add_model_message( - parsed_llm_response.full_response, messages, parsed_llm_response - ) - - return parsed_llm_response - async def get_model_response(self) -> ParsedLLMResponse: session_context = SESSION_CONTEXT.get() stream = session_context.stream @@ -277,7 +216,10 @@ async def get_model_response(self) -> ParsedLLMResponse: raise_if_context_exceeds_max(tokens_used) try: - response = await self._stream_model_response(messages_snapshot) + if session_context.config.two_step_edits: + response = await stream_model_response_two_step(messages_snapshot) + else: + response = await stream_model_response(messages_snapshot) except RateLimitError: stream.send( "Rate limit error received from OpenAI's servers using model" @@ -286,6 +228,14 @@ async def get_model_response(self) -> ParsedLLMResponse: style="error", ) return ParsedLLMResponse("", "", list[FileEdit]()) + + messages_snapshot.append( + ChatCompletionAssistantMessageParam( + role="assistant", content=response.full_response + ) + ) + self.add_model_message(response.full_response, messages_snapshot, response) + return response async def remaining_context(self) -> int | None: diff --git a/mentat/parsers/git_parser.py b/mentat/parsers/git_parser.py index da25bbb25..7df10ccfd 100644 --- a/mentat/parsers/git_parser.py +++ b/mentat/parsers/git_parser.py @@ -143,7 +143,9 @@ def parse_llm_response(self, content: str) -> ParsedLLMResponse: file_edits.append(file_edit) return ParsedLLMResponse( - f"{conversation}\n\n{git_diff}", conversation, file_edits + f"{conversation}\n\n{git_diff}", + conversation, + file_edits, ) def file_edit_to_git_diff(self, file_edit: FileEdit) -> str: diff --git a/mentat/parsers/parser.py b/mentat/parsers/parser.py index 32e740ea0..38c4f019b 100644 --- a/mentat/parsers/parser.py +++ b/mentat/parsers/parser.py @@ -35,6 +35,7 @@ class ParsedLLMResponse: full_response: str = attr.field() conversation: str = attr.field() file_edits: list[FileEdit] = attr.field() + rewritten_files: list[tuple[str, str]] = attr.field(factory=list) interrupted: bool = attr.field(default=False) @@ -337,6 +338,7 @@ async def stream_and_parse_llm_response( message, conversation, [file_edit for file_edit in file_edits.values()], + [], interrupted, ) diff --git a/mentat/resources/prompts/two_step_edit_prompt.txt b/mentat/resources/prompts/two_step_edit_prompt.txt new file mode 100644 index 000000000..9ffb16951 --- /dev/null +++ b/mentat/resources/prompts/two_step_edit_prompt.txt @@ -0,0 +1,14 @@ +**You are now operating within an advanced AI coding system designed to 
assist with code modifications and enhancements.** + +Upon receiving context, which may range from specific code snippets to entire repositories, you will be tasked with addressing coding requests or answering questions. + +**For your responses:** + +- **Directly address the request or question:** Provide concise instructions for any code modifications, clearly stating what changes need to be made. +- **Specify modifications without reiterating existing code:** Guide the user on where and how to make modifications, e.g., "insert the new code block above the last function in the file" or "replace the existing loop condition with the provided snippet." Ensure instructions are clear without displaying how the entire file looks post-modification. +- **Use the full file path at least once per file with edits:** When mentioning a file for the first time, use its full path. You can refer to it by a shorter name afterward if it remains clear which file you're discussing. +- **Avoid suggesting non-actionable edits:** Do not recommend commenting out or non-specific removals. Be explicit about what to delete or change, referring to code blocks or functions by name and avoiding extensive verbatim rewrites. +- **Minimize the inclusion of unchanged code:** Focus on the new or altered lines rather than embedding them within large blocks of unchanged code. Your guidance should be clear enough for an intelligent actor to implement with just the changes specified. +- **Emphasize brevity and clarity:** Once you've provided detailed instructions for the edits, there's no need for further elaboration. Avoid concluding with summaries of how the code will look after the edits. + +**Your guidance should empower users to confidently implement the suggested changes with minimal and precise directions, fostering an efficient and clear modification process.** diff --git a/mentat/resources/prompts/two_step_edit_prompt_list_files.txt b/mentat/resources/prompts/two_step_edit_prompt_list_files.txt new file mode 100644 index 000000000..1f6fd1acc --- /dev/null +++ b/mentat/resources/prompts/two_step_edit_prompt_list_files.txt @@ -0,0 +1,10 @@ +You are part of an expert AI coding system. + +The next message will be an answer to a user's question or request. It may include suggested edits to code files. Your job is simply to extract the names of files that edits need to be made to, according to that message. + +In your response: + - respond in json, with a single key "files" and a value that is an array of strings + - return empty array if no files have suggested edits, e.g. {"files":[]} + - the message may mention files without suggesting edits to them, do not include these. Only include files that have suggested edits + - if a file is meant to be created, include it in the list of files to edit + diff --git a/mentat/resources/prompts/two_step_edit_prompt_rewrite_file.txt b/mentat/resources/prompts/two_step_edit_prompt_rewrite_file.txt new file mode 100644 index 000000000..73d1ae880 --- /dev/null +++ b/mentat/resources/prompts/two_step_edit_prompt_rewrite_file.txt @@ -0,0 +1,11 @@ +You are part of an expert AI coding system. + +In the next message you will be given the contents of a code file. The user will then specify some edits to be made to the file. 
+ +Your response should: + - rewrite the entire file, including all the requested edits + - wrap your entire response in ``` + - do not include anything else in your response other than the code + - do not make any other changes to the code other than the requested edits, even to standardize formatting + - even formatting changes should not be made unless explicitly requested by the user + - if a change is not fully specified, do your best to follow the spirit of what was asked diff --git a/mentat/session.py b/mentat/session.py index 813e29c42..d0e6135e5 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -19,7 +19,10 @@ from mentat.agent_handler import AgentHandler from mentat.auto_completer import AutoCompleter from mentat.code_context import CodeContext -from mentat.code_edit_feedback import get_user_feedback_on_edits +from mentat.code_edit_feedback import ( + get_user_feedback_on_edits, + get_user_feedback_on_edits_two_step, +) from mentat.code_file_manager import CodeFileManager from mentat.config import Config from mentat.conversation import Conversation @@ -170,41 +173,75 @@ async def _main(self): conversation.add_user_message(message.data) parsed_llm_response = await conversation.get_model_response() - file_edits = [ - file_edit - for file_edit in parsed_llm_response.file_edits - if file_edit.is_valid() - ] - for file_edit in file_edits: - file_edit.resolve_conflicts() - if file_edits: - if session_context.config.revisor: - await revise_edits(file_edits) - - if not agent_handler.agent_enabled: - file_edits, need_user_request = ( - await get_user_feedback_on_edits(file_edits) + if not session_context.config.two_step_edits: + file_edits = [ + file_edit + for file_edit in parsed_llm_response.file_edits + if file_edit.is_valid() + ] + for file_edit in file_edits: + file_edit.resolve_conflicts() + if file_edits: + if session_context.config.revisor: + await revise_edits(file_edits) + + if not agent_handler.agent_enabled: + file_edits, need_user_request = ( + await get_user_feedback_on_edits(file_edits) + ) + + if session_context.config.sampler: + session_context.sampler.set_active_diff() + + applied_edits = await code_file_manager.write_changes_to_files( + file_edits + ) + stream.send( + ( + "Changes applied." + if applied_edits + else "No changes applied." + ), + style="input", ) - if session_context.config.sampler: - session_context.sampler.set_active_diff() - - applied_edits = await code_file_manager.write_changes_to_files( - file_edits - ) - stream.send( - "Changes applied." 
if applied_edits else "No changes applied.", - style="input", - ) - - if agent_handler.agent_enabled: - if parsed_llm_response.interrupted: - need_user_request = True - else: - need_user_request = await agent_handler.add_agent_context() + if agent_handler.agent_enabled: + if parsed_llm_response.interrupted: + need_user_request = True + else: + need_user_request = ( + await agent_handler.add_agent_context() + ) + else: + need_user_request = True + stream.send(bool(file_edits), channel="edits_complete") else: - need_user_request = True - stream.send(bool(file_edits), channel="edits_complete") + rewritten_files = parsed_llm_response.rewritten_files + if rewritten_files: + if not agent_handler.agent_enabled: + rewritten_files, need_user_request = ( + await get_user_feedback_on_edits_two_step( + rewritten_files + ) + ) + + if session_context.config.sampler: + session_context.sampler.set_active_diff() + + applied_rewritten_files = ( + await code_file_manager.write_changes_to_files_two_step( + rewritten_files + ) + ) + stream.send( + ( + "Changes applied." + if applied_rewritten_files + else "No changes applied." + ), + style="input", + ) + except SessionExit: stream.send(None, channel="client_exit") break diff --git a/mentat/stream_model_response.py b/mentat/stream_model_response.py new file mode 100644 index 000000000..bbb87e239 --- /dev/null +++ b/mentat/stream_model_response.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from difflib import ndiff +from json import JSONDecodeError +from pathlib import Path +from typing import Any, AsyncIterator + +from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionSystemMessageParam, +) +from openai.types.chat.completion_create_params import ResponseFormat + +from mentat.llm_api_handler import TOKEN_COUNT_WARNING, count_tokens, prompt_tokens +from mentat.parsers.parser import ParsedLLMResponse +from mentat.parsers.streaming_printer import StreamingPrinter +from mentat.prompts.prompts import read_prompt +from mentat.session_context import SESSION_CONTEXT +from mentat.utils import add_newline + +two_step_edit_prompt_filename = Path("two_step_edit_prompt.txt") +two_step_edit_prompt_list_files_filename = Path("two_step_edit_prompt_list_files.txt") +two_step_edit_prompt_rewrite_file_filename = Path( + "two_step_edit_prompt_rewrite_file.txt" +) + + +async def stream_model_response( + messages: list[ChatCompletionMessageParam], +) -> ParsedLLMResponse: + session_context = SESSION_CONTEXT.get() + stream = session_context.stream + code_file_manager = session_context.code_file_manager + config = session_context.config + parser = config.parser + llm_api_handler = session_context.llm_api_handler + cost_tracker = session_context.cost_tracker + + num_prompt_tokens = prompt_tokens(messages, config.model) + stream.send(f"Total token count: {num_prompt_tokens}", style="info") + if num_prompt_tokens > TOKEN_COUNT_WARNING: + stream.send( + "Warning: LLM performance drops off rapidly at large context sizes. Use" + " /clear to clear context or use /exclude to exclude any uneccessary" + " files.", + style="warning", + ) + + stream.send( + None, + channel="loading", + ) + response = await llm_api_handler.call_llm_api( + messages, + config.model, + stream=True, + response_format=parser.response_format(), + ) + stream.send( + None, + channel="loading", + terminate=True, + ) + + stream.send("Streaming... 
use control-c to interrupt the model at any point\n") + async with parser.interrupt_catcher(): + parsed_llm_response = await parser.stream_and_parse_llm_response( + add_newline(response) + ) + # Sampler and History require previous_file_lines + for file_edit in parsed_llm_response.file_edits: + file_edit.previous_file_lines = code_file_manager.file_lines.get( + file_edit.file_path, [] + ) + if not parsed_llm_response.interrupted: + cost_tracker.display_last_api_call() + else: + # Generator doesn't log the api call if we interrupt it + cost_tracker.log_api_call_stats( + num_prompt_tokens, + count_tokens( + parsed_llm_response.full_response, config.model, full_message=False + ), + config.model, + display=True, + ) + + return parsed_llm_response + + +def get_two_step_system_prompt() -> str: + return read_prompt(two_step_edit_prompt_filename) + + +def get_two_step_list_files_prompt() -> str: + return read_prompt(two_step_edit_prompt_list_files_filename) + + +def get_two_step_rewrite_file_prompt() -> str: + return read_prompt(two_step_edit_prompt_rewrite_file_filename) + + +def print_colored_diff(str1: str, str2: str, stream: Any): + diff = ndiff(str1.splitlines(), str2.splitlines()) + + for line in diff: + if line.startswith("-"): + stream.send(line, color="red") + elif line.startswith("+"): + stream.send(line, color="green") + elif line.startswith("?"): + pass # skip printing the ? lines ndiff produces + else: + stream.send(line) + + +async def stream_model_response_two_step( + messages: list[ChatCompletionMessageParam], +) -> ParsedLLMResponse: + session_context = SESSION_CONTEXT.get() + stream = session_context.stream + code_file_manager = session_context.code_file_manager + config = session_context.config + parser = config.parser + llm_api_handler = session_context.llm_api_handler + cwd = session_context.cwd + + num_prompt_tokens = prompt_tokens(messages, config.model) + stream.send(f"Total token count: {num_prompt_tokens}", style="info") + if num_prompt_tokens > TOKEN_COUNT_WARNING: + stream.send( + "Warning: LLM performance drops off rapidly at large context sizes. Use" + " /clear to clear context or use /exclude to exclude any uneccessary" + " files.", + style="warning", + ) + + stream.send( + None, + channel="loading", + ) + response = await llm_api_handler.call_llm_api( + messages, + config.model, + stream=True, + response_format=parser.response_format(), + ) + stream.send( + None, + channel="loading", + terminate=True, + ) + + # TODO: if using two step, don't add line numbers to context - might help + # TODO: identify files mentioned, rewrite them with new calls + # TODO: instead of FileEdit objects, return new rewritten files? + # TODO: add interrupt ability + # TODO: make sure to track costs of all calls and log api calls + stream.send("Streaming... 
use control-c to interrupt the model at any point\n") + first_message = await stream_and_parse_llm_response_two_step(response) + + stream.send( + "\n\n### Initial Response Complete - parsing edits: ###\n", style="info" + ) + + list_files_messages: list[ChatCompletionMessageParam] = [ + ChatCompletionSystemMessageParam( + role="system", + content=get_two_step_list_files_prompt(), + ), + ChatCompletionSystemMessageParam( + role="system", + content=first_message, + ), + ] + + list_files_response = await llm_api_handler.call_llm_api( + list_files_messages, + model="gpt-3.5-turbo-0125", # TODO add config for secondary model + stream=False, + response_format=ResponseFormat(type="json_object"), + ) + + response_json = {} + if list_files_response.choices and list_files_response.choices[0].message.content: + try: + response_json: dict[str, list[str]] = json.loads( + list_files_response.choices[0].message.content + ) + except JSONDecodeError: + stream.send("Error processing model response: Invalid JSON", style="error") + # TODO: handle error + + stream.send(f"\n{response_json}\n") + + # TODO remove line numbers when running two step edit + # TODO handle creating new files - including update prompt to know that's possible + + rewritten_files: list[tuple[str, str]] = [] + + file_paths: list[str] = response_json.get("files", []) + for file_path in file_paths: + full_path = (cwd / Path(file_path)).resolve() + code_file_lines = code_file_manager.file_lines.get(full_path, []) + code_file_string = "\n".join(code_file_lines) + + rewrite_file_messages: list[ChatCompletionMessageParam] = [ + ChatCompletionSystemMessageParam( + role="system", + content=get_two_step_rewrite_file_prompt(), + ), + ChatCompletionSystemMessageParam( + role="system", + content=code_file_string, + ), + ChatCompletionSystemMessageParam( + role="system", # TODO: change to user? 
not sure + content=first_message, + ), + ] + + rewrite_file_response = await llm_api_handler.call_llm_api( + rewrite_file_messages, + model="gpt-3.5-turbo-0125", # TODO add config for secondary model + stream=False, + ) + if ( + rewrite_file_response + and rewrite_file_response.choices + and rewrite_file_response.choices[0].message.content + ): + rewrite_file_response = rewrite_file_response.choices[0].message.content + lines = rewrite_file_response.splitlines() + # TODO remove asserts + assert "```" in lines[0] + assert "```" in lines[-1] + lines = lines[1:-1] + rewrite_file_response = "\n".join(lines) + + rewritten_files.append((str(full_path), rewrite_file_response)) + + stream.send(f"\n### File Rewrite Response: {file_path} ###\n") + # stream.send(rewrite_file_response) + + # TODO stream colored diff, skipping unchanged lines (except some for context) + print_colored_diff(code_file_string, rewrite_file_response, stream) + + # async with parser.interrupt_catcher(): + # parsed_llm_response = await parser.stream_and_parse_llm_response( + # add_newline(response) + # ) + + # # Sampler and History require previous_file_lines + # for file_edit in parsed_llm_response.file_edits: + # file_edit.previous_file_lines = code_file_manager.file_lines.get( + # file_edit.file_path, [] + # ) + # if not parsed_llm_response.interrupted: + # cost_tracker.display_last_api_call() + # else: + # # Generator doesn't log the api call if we interrupt it + # cost_tracker.log_api_call_stats( + # num_prompt_tokens, + # count_tokens( + # parsed_llm_response.full_response, config.model, full_message=False + # ), + # config.model, + # display=True, + # ) + + return ParsedLLMResponse( + full_response=first_message, + conversation=first_message, + # [file_edit for file_edit in file_edits.values()], + file_edits=[], + rewritten_files=rewritten_files, + interrupted=False, + ) + + +async def stream_and_parse_llm_response_two_step( + response: AsyncIterator[ChatCompletionChunk], +) -> str: + printer = StreamingPrinter() + printer_task = asyncio.create_task(printer.print_lines()) + + message = "" + + async for chunk in response: + # if self.shutdown.is_set(): + # interrupted = True + # printer.shutdown_printer() + # if printer_task is not None: + # await printer_task + # stream.send("\n\nInterrupted by user. Using the response up to this point.") + # break + + content = "" + if len(chunk.choices) > 0: + content = chunk.choices[0].delta.content or "" + + message += content + printer.add_string(content, end="") + + # Only finish printing if we don't quit from ctrl-c + printer.wrap_it_up() + await printer_task + + logging.debug("LLM Response:") + logging.debug(message) + + return message
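
The two-step flow added in `mentat/stream_model_response.py` works in three stages: the primary model answers the user's request in prose, a secondary model (`gpt-3.5-turbo-0125` in the patch, with a TODO to make it configurable) extracts the files that have suggested edits as JSON, and one further call per file rewrites that file in full. A minimal sketch of the file-extraction stage is below, assuming an OpenAI-style async client; `client` and `extract_files_to_edit` are illustrative stand-ins rather than the patch's `llm_api_handler` path.

```python
import json
from json import JSONDecodeError

from openai import AsyncOpenAI

# Illustrative client; the patch routes this call through session_context.llm_api_handler.
client = AsyncOpenAI()

LIST_FILES_PROMPT = (
    'Respond in JSON with a single key "files" whose value is an array of '
    "file paths that have suggested edits (an empty array if there are none)."
)


async def extract_files_to_edit(first_message: str) -> list[str]:
    """Ask a secondary model which files the prose answer suggests editing."""
    response = await client.chat.completions.create(
        model="gpt-3.5-turbo-0125",  # the patch has a TODO to make this configurable
        messages=[
            {"role": "system", "content": LIST_FILES_PROMPT},
            {"role": "system", "content": first_message},
        ],
        response_format={"type": "json_object"},
    )
    content = response.choices[0].message.content or "{}"
    try:
        parsed = json.loads(content)
    except JSONDecodeError:
        # The patch reports "Invalid JSON" and falls back to no edits; do the same here.
        return []
    files = parsed.get("files", []) if isinstance(parsed, dict) else []
    return [f for f in files if isinstance(f, str)]
```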
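
Each per-file rewrite response is expected to wrap the whole file in ``` fences; the patch currently enforces that with two `assert`s and a `# TODO remove asserts`. A more tolerant way to strip the fences, shown only as a sketch of one option rather than what the patch ships:

```python
def strip_code_fences(response_text: str) -> str:
    """Return the body of a ```-fenced response, tolerating missing fences."""
    lines = response_text.splitlines()
    # Drop leading/trailing blank lines before looking for fences.
    while lines and not lines[0].strip():
        lines.pop(0)
    while lines and not lines[-1].strip():
        lines.pop()
    if lines and lines[0].lstrip().startswith("```"):
        lines = lines[1:]
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines)
```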
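
`print_colored_diff` streams every line `ndiff` produces, and a TODO notes that unchanged lines should eventually be skipped apart from a little surrounding context. `difflib.unified_diff` already does that grouping, so one possible variant (a sketch swapping in `unified_diff`, not the patch's implementation) looks like this; `stream.send(..., color=...)` follows the same calling convention the patch itself uses.

```python
from difflib import unified_diff
from typing import Any


def print_colored_diff_with_context(
    old: str, new: str, stream: Any, context: int = 3
) -> None:
    """Stream only changed lines plus `context` unchanged lines around them."""
    diff = unified_diff(old.splitlines(), new.splitlines(), lineterm="", n=context)
    for line in diff:
        if line.startswith("---") or line.startswith("+++"):
            continue  # drop the (empty) file headers unified_diff emits
        if line.startswith("@@"):
            stream.send(line)  # hunk header marks where unchanged lines were skipped
        elif line.startswith("-"):
            stream.send(line, color="red")
        elif line.startswith("+"):
            stream.send(line, color="green")
        else:
            stream.send(line)
```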