"""Portkey-backed model that uses the Responses API with native tool calling."""

import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Literal

import litellm
from pydantic import BaseModel

from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.actions_toolcall_response import (
    BASH_TOOL_RESPONSE_API,
    format_toolcall_observation_messages,
    parse_toolcall_actions_response,
)
from minisweagent.models.utils.retry import retry

logger = logging.getLogger("portkey_response_model")

try:
    from portkey_ai import Portkey
except ImportError as e:
    # Chain explicitly so the underlying import failure stays attached as __cause__.
    raise ImportError(
        "The portkey-ai package is required to use PortkeyResponseAPIModel. Please install it with: pip install portkey-ai"
    ) from e


class PortkeyResponseAPIModelConfig(BaseModel):
    """Configuration for :class:`PortkeyResponseAPIModel`.

    Note that the environment-variable defaults below are read once, when this
    class body is executed (i.e. at import time), not per instance.
    """

    # Sent as ``model=`` on every request; also used for cost lookup unless
    # ``litellm_model_name_override`` is set.
    model_name: str
    # Extra keyword arguments for ``responses.create``; per-call kwargs passed
    # to ``query`` take precedence on key conflicts.
    model_kwargs: dict[str, Any] = {}
    # Path to a JSON file handed to ``litellm.utils.register_model``; ignored
    # if unset or not an existing file.
    litellm_model_registry: Path | str | None = os.getenv("LITELLM_MODEL_REGISTRY_PATH")
    # When non-empty, used instead of ``model_name`` for the litellm cost lookup.
    litellm_model_name_override: str = ""
    # "ignore_errors" records a cost of 0.0 when cost calculation fails;
    # "default" raises instead.
    cost_tracking: Literal["default", "ignore_errors"] = os.getenv("MSWEA_COST_TRACKING", "default")
    # Forwarded to ``parse_toolcall_actions_response``.
    format_error_template: str = "{{ error }}"
    # Forwarded to ``format_toolcall_observation_messages``
    # (looks like a Jinja-style template -- confirm against that helper).
    observation_template: str = (
        "{% if output.exception_info %}{{output.exception_info}}\n{% endif %}"
        "{{output.returncode}}\n\n{{output.output}}"
    )
    # Forwarded to ``format_toolcall_observation_messages``.
    multimodal_regex: str = ""


class PortkeyResponseAPIModel:
    """Portkey model using the Responses API with native tool calling.

    Note: This implementation is stateless - each request must include
    the full conversation history. previous_response_id is not used.
    """

    # Passed to ``retry`` as ``abort_exceptions``. Annotated with BaseException
    # because KeyboardInterrupt does not derive from Exception.
    abort_exceptions: list[type[BaseException]] = [KeyboardInterrupt, TypeError, ValueError]

    def __init__(self, **kwargs):
        """Build the config from ``kwargs`` and create the Portkey client.

        Raises:
            ValueError: If the ``PORTKEY_API_KEY`` environment variable is
                unset or empty.
        """
        self.config = PortkeyResponseAPIModelConfig(**kwargs)

        registry = self.config.litellm_model_registry
        if registry and Path(registry).is_file():
            litellm.utils.register_model(json.loads(Path(registry).read_text()))

        self._api_key = os.getenv("PORTKEY_API_KEY")
        if not self._api_key:
            raise ValueError(
                "Portkey API key is required. Set it via the "
                "PORTKEY_API_KEY environment variable. You can permanently set it with "
                "`mini-extra config set PORTKEY_API_KEY YOUR_KEY`."
            )

        # The virtual key is optional and only forwarded when present.
        virtual_key = os.getenv("PORTKEY_VIRTUAL_KEY")
        client_kwargs = {"api_key": self._api_key}
        if virtual_key:
            client_kwargs["virtual_key"] = virtual_key
        self.client = Portkey(**client_kwargs)

    def _query(self, messages: list[dict[str, str]], **kwargs):
        """Send one request with the bash tool attached and return the raw response."""
        return self.client.responses.create(
            model=self.config.model_name,
            input=messages,
            tools=[BASH_TOOL_RESPONSE_API],
            # Right-hand operand wins, so per-call kwargs override config defaults.
            **(self.config.model_kwargs | kwargs),
        )

    def _prepare_messages_for_api(self, messages: list[dict]) -> list[dict]:
        """Prepare messages for Portkey's stateless Responses API.

        Entries whose ``"object"`` is ``"response"`` are flattened into their
        ``"output"`` items; all other entries are passed through. The ``"extra"``
        key is dropped from every emitted item.
        """
        result = []
        for msg in messages:
            if msg.get("object") == "response":
                for item in msg.get("output", []):
                    result.append({k: v for k, v in item.items() if k != "extra"})
            else:
                result.append({k: v for k, v in msg.items() if k != "extra"})
        return result

    def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
        """Query the model and return the response as a dict with an ``"extra"`` entry.

        ``"extra"`` holds the parsed tool-call actions, the cost and a timestamp.
        The cost is also added to ``GLOBAL_MODEL_STATS``.
        """
        # The request is issued inside the attempt contexts yielded by ``retry``
        # (presumably re-run on failure, except for ``abort_exceptions``).
        for attempt in retry(logger=logger, abort_exceptions=self.abort_exceptions):
            with attempt:
                response = self._query(self._prepare_messages_for_api(messages), **kwargs)
        cost_output = self._calculate_cost(response)
        GLOBAL_MODEL_STATS.add(cost_output["cost"])
        message = response.model_dump() if hasattr(response, "model_dump") else dict(response)
        message["extra"] = {
            "actions": self._parse_actions(response),
            **cost_output,
            "timestamp": time.time(),
        }
        return message

    def _parse_actions(self, response) -> list[dict]:
        """Parse tool calls from the response API response."""
        # Accept both attribute-style response objects and plain mappings.
        output = response.output if hasattr(response, "output") else response.get("output", [])
        return parse_toolcall_actions_response(output, format_error_template=self.config.format_error_template)

    def _calculate_cost(self, response) -> dict[str, float]:
        """Return ``{"cost": <float>}`` for ``response`` using litellm's cost calculator.

        Raises:
            RuntimeError: If the cost cannot be computed or is not positive,
                unless ``cost_tracking`` is ``"ignore_errors"`` (then the cost
                is reported as 0.0).
        """
        try:
            cost = litellm.cost_calculator.completion_cost(
                response, model=self.config.litellm_model_name_override or self.config.model_name
            )
            # Explicit check rather than ``assert`` so that it is not stripped
            # when Python runs with -O.
            if not cost > 0.0:
                raise ValueError(f"Cost is not positive: {cost}")
        except Exception as e:
            if self.config.cost_tracking != "ignore_errors":
                raise RuntimeError(
                    f"Error calculating cost for model {self.config.model_name}: {e}. "
                    "You can ignore this issue from your config file with cost_tracking: 'ignore_errors' or "
                    "globally with export MSWEA_COST_TRACKING='ignore_errors' to ignore this error. "
                ) from e
            cost = 0.0
        return {"cost": cost}

    def format_message(self, **kwargs) -> dict:
        """Build a Responses-API ``"message"`` item from ``role``/``content``/``extra``.

        String content is wrapped in a single ``input_text`` item; any other
        content value is used as-is. ``extra`` is attached only when truthy.
        """
        role = kwargs.get("role", "user")
        content = kwargs.get("content", "")
        extra = kwargs.get("extra")
        content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
        msg = {"type": "message", "role": role, "content": content_items}
        if extra:
            msg["extra"] = extra
        return msg

    def format_observation_messages(
        self, message: dict, outputs: list[dict], template_vars: dict | None = None
    ) -> list[dict]:
        """Format execution outputs into tool result messages."""
        # Actions are the ones stored under message["extra"]["actions"] by ``query``.
        actions = message.get("extra", {}).get("actions", [])
        return format_toolcall_observation_messages(
            actions=actions,
            outputs=outputs,
            observation_template=self.config.observation_template,
            template_vars=template_vars,
            multimodal_regex=self.config.multimodal_regex,
        )

    def get_template_vars(self, **kwargs) -> dict:
        """Return the config as a dict; ``kwargs`` are accepted but unused."""
        return self.config.model_dump()

    def serialize(self) -> dict:
        """Return a JSON-compatible description of the config and the model class path."""
        return {
            "info": {
                "config": {
                    "model": self.config.model_dump(mode="json"),
                    "model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
                },
            }
        }