import logging
import time
from typing import Any
from pydantic import BaseModel
from minisweagent.models import GLOBAL_MODEL_STATS
from minisweagent.models.utils.actions_text import format_observation_messages
from minisweagent.models.utils.actions_toolcall import format_toolcall_observation_messages
from minisweagent.models.utils.actions_toolcall_response import (
format_toolcall_observation_messages as format_response_api_observation_messages,
)
from minisweagent.models.utils.openai_multimodal import expand_multimodal_content
def make_output(content: str, actions: list[dict], cost: float = 1.0) -> dict:
"""Helper to create an output dict for DeterministicModel.
Args:
content: The response content string
actions: List of action dicts, e.g., [{"command": "echo hello"}]
cost: Cost to report for this output (default 1.0)
"""
return {
"role": "assistant",
"content": content,
"extra": {"actions": actions, "cost": cost, "timestamp": time.time()},
}
def make_toolcall_output(content: str | None, tool_calls: list[dict], actions: list[dict]) -> dict:
"""Helper to create a toolcall output dict for DeterministicToolcallModel.
Args:
content: Optional text content (can be None for tool-only responses)
tool_calls: List of tool call dicts in OpenAI format
actions: List of parsed action dicts, e.g., [{"command": "echo hello", "tool_call_id": "call_123"}]
"""
return {
"role": "assistant",
"content": content,
"tool_calls": tool_calls,
"extra": {"actions": actions, "cost": 1.0, "timestamp": time.time()},
}
def make_response_api_output(content: str | None, actions: list[dict]) -> dict:
"""Helper to create an output dict for DeterministicResponseAPIToolcallModel.
Args:
content: Optional text content (can be None for tool-only responses)
actions: List of action dicts with 'command' and 'tool_call_id' keys
"""
output_items = []
if content:
output_items.append(
{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": content}]}
)
for action in actions:
output_items.append(
{
"type": "function_call",
"call_id": action["tool_call_id"],
"name": "bash",
"arguments": f'{{"command": "{action["command"]}"}}',
}
)
return {
"object": "response",
"output": output_items,
"extra": {"actions": actions, "cost": 1.0, "timestamp": time.time()},
}
def _process_test_actions(actions: list[dict]) -> bool:
"""Process special test actions. Returns True if the query should be retried."""
for action in actions:
if "raise" in action:
raise action["raise"]
cmd = action.get("command", "")
if cmd.startswith("/sleep "):
time.sleep(float(cmd.split("/sleep ")[1]))
return True
if cmd.startswith("/warning"):
logging.warning(cmd.split("/warning")[1])
return True
return False
class DeterministicModelConfig(BaseModel):
outputs: list[dict]
"""List of exact output messages to return in sequence. Each dict should have 'role', 'content', and 'extra' (with 'actions')."""
model_name: str = "deterministic"
cost_per_call: float = 1.0
observation_template: str = (
"{% if output.exception_info %}{{output.exception_info}}\n{% endif %}"
"{{output.returncode}}\n"
)
"""Template used to render the observation after executing an action."""
multimodal_regex: str = ""
"""Regex to extract multimodal content. Empty string disables multimodal processing."""
class DeterministicModel:
def __init__(self, **kwargs):
"""Initialize with a list of output messages to return in sequence."""
self.config = DeterministicModelConfig(**kwargs)
self.current_index = -1
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
self.current_index += 1
output = self.config.outputs[self.current_index]
if _process_test_actions(output.get("extra", {}).get("actions", [])):
return self.query(messages, **kwargs)
GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
return output
def format_message(self, **kwargs) -> dict:
return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
def format_observation_messages(
self, message: dict, outputs: list[dict], template_vars: dict | None = None
) -> list[dict]:
"""Format execution outputs into observation messages."""
return format_observation_messages(
outputs,
observation_template=self.config.observation_template,
template_vars=template_vars,
multimodal_regex=self.config.multimodal_regex,
)
def get_template_vars(self, **kwargs) -> dict[str, Any]:
return self.config.model_dump()
def serialize(self) -> dict:
return {
"info": {
"config": {
"model": self.config.model_dump(mode="json"),
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
},
}
}
class DeterministicToolcallModelConfig(BaseModel):
outputs: list[dict]
"""List of exact output messages with tool_calls to return in sequence."""
model_name: str = "deterministic_toolcall"
cost_per_call: float = 1.0
observation_template: str = (
"{% if output.exception_info %}{{output.exception_info}}\n{% endif %}"
"{{output.returncode}}\n"
)
"""Template used to render the observation after executing an action."""
multimodal_regex: str = ""
"""Regex to extract multimodal content. Empty string disables multimodal processing."""
class DeterministicToolcallModel:
def __init__(self, **kwargs):
"""Initialize with a list of toolcall output messages to return in sequence."""
self.config = DeterministicToolcallModelConfig(**kwargs)
self.current_index = -1
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
self.current_index += 1
output = self.config.outputs[self.current_index]
if _process_test_actions(output.get("extra", {}).get("actions", [])):
return self.query(messages, **kwargs)
GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
return output
def format_message(self, **kwargs) -> dict:
return expand_multimodal_content(kwargs, pattern=self.config.multimodal_regex)
def format_observation_messages(
self, message: dict, outputs: list[dict], template_vars: dict | None = None
) -> list[dict]:
"""Format execution outputs into tool result messages."""
actions = message.get("extra", {}).get("actions", [])
return format_toolcall_observation_messages(
actions=actions,
outputs=outputs,
observation_template=self.config.observation_template,
template_vars=template_vars,
multimodal_regex=self.config.multimodal_regex,
)
def get_template_vars(self, **kwargs) -> dict[str, Any]:
return self.config.model_dump()
def serialize(self) -> dict:
return {
"info": {
"config": {
"model": self.config.model_dump(mode="json"),
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
},
}
}
class DeterministicResponseAPIToolcallModelConfig(BaseModel):
outputs: list[dict]
"""List of exact Response API output messages to return in sequence."""
model_name: str = "deterministic_response_api_toolcall"
cost_per_call: float = 1.0
observation_template: str = (
"{% if output.exception_info %}{{output.exception_info}}\n{% endif %}"
"{{output.returncode}}\n"
)
"""Template used to render the observation after executing an action."""
multimodal_regex: str = ""
"""Regex to extract multimodal content. Empty string disables multimodal processing."""
class DeterministicResponseAPIToolcallModel:
"""Deterministic test model using OpenAI Responses API format."""
def __init__(self, **kwargs):
"""Initialize with a list of Response API output messages to return in sequence."""
self.config = DeterministicResponseAPIToolcallModelConfig(**kwargs)
self.current_index = -1
def query(self, messages: list[dict[str, str]], **kwargs) -> dict:
self.current_index += 1
output = self.config.outputs[self.current_index]
if _process_test_actions(output.get("extra", {}).get("actions", [])):
return self.query(messages, **kwargs)
GLOBAL_MODEL_STATS.add(self.config.cost_per_call)
return output
def format_message(self, **kwargs) -> dict:
"""Format message in Responses API format."""
role = kwargs.get("role", "user")
content = kwargs.get("content", "")
extra = kwargs.get("extra")
content_items = [{"type": "input_text", "text": content}] if isinstance(content, str) else content
msg: dict = {"type": "message", "role": role, "content": content_items}
if extra:
msg["extra"] = extra
return msg
def format_observation_messages(
self, message: dict, outputs: list[dict], template_vars: dict | None = None
) -> list[dict]:
"""Format execution outputs into function_call_output messages."""
actions = message.get("extra", {}).get("actions", [])
return format_response_api_observation_messages(
actions=actions,
outputs=outputs,
observation_template=self.config.observation_template,
template_vars=template_vars,
multimodal_regex=self.config.multimodal_regex,
)
def get_template_vars(self, **kwargs) -> dict[str, Any]:
return self.config.model_dump()
def serialize(self) -> dict:
return {
"info": {
"config": {
"model": self.config.model_dump(mode="json"),
"model_type": f"{self.__class__.__module__}.{self.__class__.__name__}",
},
}
}