"""A small generalization of the default agent that puts the user in the loop. There are three modes: - human: commands issued by the user are executed immediately - confirm: commands issued by the LM but not whitelisted are confirmed by the user - yolo: commands issued by the LM are executed immediately without confirmation """ import re from typing import Literal, NoReturn from prompt_toolkit.formatted_text import HTML from prompt_toolkit.history import FileHistory from prompt_toolkit.shortcuts import PromptSession from rich.console import Console from rich.rule import Rule from minisweagent import global_config_dir from minisweagent.agents.default import AgentConfig, DefaultAgent from minisweagent.exceptions import LimitsExceeded, Submitted, UserInterruption from minisweagent.models.utils.content_string import get_content_string console = Console(highlight=False) _history = FileHistory(global_config_dir / "interactive_history.txt") _prompt_session = PromptSession(history=_history) _multiline_prompt_session = PromptSession(history=_history, multiline=True) class InteractiveAgentConfig(AgentConfig): mode: Literal["human", "confirm", "yolo"] = "confirm" """Whether to confirm actions.""" whitelist_actions: list[str] = [] """Never confirm actions that match these regular expressions.""" confirm_exit: bool = True """If the agent wants to finish, do we ask for confirmation from user?""" def _multiline_prompt() -> str: return _multiline_prompt_session.prompt( "", bottom_toolbar=HTML( "Submit message: Esc, then Enter | " "Navigate history: Arrow Up/Down | " "Search history: Ctrl+R" ), ) class InteractiveAgent(DefaultAgent): _MODE_COMMANDS_MAPPING = {"/u": "human", "/c": "confirm", "/y": "yolo"} def __init__(self, *args, config_class=InteractiveAgentConfig, **kwargs): super().__init__(*args, config_class=config_class, **kwargs) self.cost_last_confirmed = 0.0 def add_messages(self, *messages: dict) -> list[dict]: # Extend supermethod to print messages for msg in messages: role, content = msg.get("role") or msg.get("type", "unknown"), get_content_string(msg) if role == "assistant": console.print( f"\n[red][bold]mini-swe-agent[/bold] (step [bold]{self.n_calls}[/bold], [bold]${self.cost:.2f}[/bold]):[/red]\n", end="", highlight=False, ) else: console.print(f"\n[bold green]{role.capitalize()}[/bold green]:\n", end="", highlight=False) console.print(content, highlight=False, markup=False) return super().add_messages(*messages) def query(self) -> dict: # Extend supermethod to handle human mode if self.config.mode == "human": match command := self._prompt_and_handle_slash_commands("[bold yellow]>[/bold yellow] "): case "/y" | "/c": pass case _: msg = { "role": "user", "content": f"User command: \n```bash\n{command}\n```", "extra": {"actions": [{"command": command}]}, } self.add_messages(msg) return msg try: with console.status("Waiting for the LM to respond..."): return super().query() except LimitsExceeded: console.print( f"Limits exceeded. Limits: {self.config.step_limit} steps, ${self.config.cost_limit}.\n" f"Current spend: {self.n_calls} steps, ${self.cost:.2f}." ) self.config.step_limit = int(input("New step limit: ")) self.config.cost_limit = float(input("New cost limit: ")) return super().query() def step(self) -> list[dict]: # Override the step method to handle user interruption try: console.print(Rule()) return super().step() except KeyboardInterrupt: interruption_message = self._prompt_and_handle_slash_commands( "\n\n[bold yellow]Interrupted.[/bold yellow] " "[green]Type a comment/command[/green] (/h for available commands)" "\n[bold yellow]>[/bold yellow] " ).strip() if not interruption_message or interruption_message in self._MODE_COMMANDS_MAPPING: interruption_message = "Temporary interruption caught." raise UserInterruption( { "role": "user", "content": f"Interrupted by user: {interruption_message}", "extra": {"interrupt_type": "UserInterruption"}, } ) def execute_actions(self, message: dict) -> list[dict]: # Override to handle user confirmation and confirm_exit, with try/finally to preserve partial outputs actions = message.get("extra", {}).get("actions", []) commands = [action["command"] for action in actions] outputs = [] try: self._ask_confirmation_or_interrupt(commands) for action in actions: outputs.append(self.env.execute(action)) except Submitted as e: self._check_for_new_task_or_submit(e) finally: result = self.add_messages( *self.model.format_observation_messages(message, outputs, self.get_template_vars()) ) return result def _add_observation_messages(self, message: dict, outputs: list[dict]) -> list[dict]: return self.add_messages(*self.model.format_observation_messages(message, outputs, self.get_template_vars())) def _check_for_new_task_or_submit(self, e: Submitted) -> NoReturn: """Check if user wants to add a new task or submit.""" if self.config.confirm_exit: message = ( "[bold yellow]Agent wants to finish.[/bold yellow] " "[bold green]Type new task[/bold green] or [red][bold]Esc, then enter[/bold] to quit.[/red]\n" "[bold yellow]>[/bold yellow] " ) if new_task := self._prompt_and_handle_slash_commands(message, _multiline=True).strip(): raise UserInterruption( { "role": "user", "content": f"The user added a new task: {new_task}", "extra": {"interrupt_type": "UserNewTask"}, } ) raise e def _should_ask_confirmation(self, action: str) -> bool: return self.config.mode == "confirm" and not any(re.match(r, action) for r in self.config.whitelist_actions) def _ask_confirmation_or_interrupt(self, commands: list[str]) -> None: commands_needing_confirmation = [c for c in commands if self._should_ask_confirmation(c)] if not commands_needing_confirmation: return n = len(commands_needing_confirmation) prompt = ( f"[bold yellow]Execute {n} action(s)?[/] [green][bold]Enter[/] to confirm[/], " "[red]type [bold]comment[/] to reject[/], or [blue][bold]/h[/] to show available commands[/]\n" "[bold yellow]>[/bold yellow] " ) match user_input := self._prompt_and_handle_slash_commands(prompt).strip(): case "" | "/y": pass # confirmed, do nothing case "/u": # Skip execution action and get back to query raise UserInterruption( { "role": "user", "content": "Commands not executed. Switching to human mode", "extra": {"interrupt_type": "UserRejection"}, } ) case _: raise UserInterruption( { "role": "user", "content": f"Commands not executed. The user rejected your commands with the following message: {user_input}", "extra": {"interrupt_type": "UserRejection"}, } ) def _prompt_and_handle_slash_commands(self, prompt: str, *, _multiline: bool = False) -> str: """Prompts the user, takes care of /h (followed by requery) and sets the mode. Returns the user input.""" console.print(prompt, end="") if _multiline: return _multiline_prompt() user_input = _prompt_session.prompt("") if user_input == "/m": return self._prompt_and_handle_slash_commands(prompt, _multiline=True) if user_input == "/h": console.print( f"Current mode: [bold green]{self.config.mode}[/bold green]\n" f"[bold green]/y[/bold green] to switch to [bold yellow]yolo[/bold yellow] mode (execute LM commands without confirmation)\n" f"[bold green]/c[/bold green] to switch to [bold yellow]confirmation[/bold yellow] mode (ask for confirmation before executing LM commands)\n" f"[bold green]/u[/bold green] to switch to [bold yellow]human[/bold yellow] mode (execute commands issued by the user)\n" f"[bold green]/m[/bold green] to enter multiline comment", ) return self._prompt_and_handle_slash_commands(prompt) if user_input in self._MODE_COMMANDS_MAPPING: if self.config.mode == self._MODE_COMMANDS_MAPPING[user_input]: return self._prompt_and_handle_slash_commands( f"[bold red]Already in {self.config.mode} mode.[/bold red]\n{prompt}" ) self.config.mode = self._MODE_COMMANDS_MAPPING[user_input] console.print(f"Switched to [bold green]{self.config.mode}[/bold green] mode.") return user_input return user_input