import json import re import threading from pathlib import Path import pytest from minisweagent.models import GLOBAL_MODEL_STATS def pytest_addoption(parser): """Add custom command line options.""" parser.addoption( "--run-fire", action="store_true", default=False, help="Run fire tests (real API calls that cost money)", ) # Global lock for tests that modify global state - this works across threads _global_stats_lock = threading.Lock() @pytest.fixture def reset_global_stats(): """Reset global model stats and ensure exclusive access for tests that need it. This fixture should be used by any test that depends on global model stats to ensure thread safety and test isolation. """ with _global_stats_lock: # Reset at start GLOBAL_MODEL_STATS._cost = 0.0 # noqa: protected-access GLOBAL_MODEL_STATS._n_calls = 0 # noqa: protected-access yield # Reset at end to clean up GLOBAL_MODEL_STATS._cost = 0.0 # noqa: protected-access GLOBAL_MODEL_STATS._n_calls = 0 # noqa: protected-access def get_test_data(trajectory_name: str) -> dict[str, list[str]]: """Load test fixtures from a trajectory JSON file""" json_path = Path(__file__).parent / "test_data" / f"{trajectory_name}.traj.json" with json_path.open() as f: trajectory = json.load(f) # Extract model responses (assistant messages, starting from index 2) model_responses = [] # Extract expected observations (user messages, starting from index 3) expected_observations = [] for i, message in enumerate(trajectory): if i < 2: # Skip system message (0) and initial user message (1) continue if message["role"] == "assistant": model_responses.append(message["content"]) elif message["role"] == "user": expected_observations.append(message["content"]) return {"model_responses": model_responses, "expected_observations": expected_observations} def normalize_outputs(s: str) -> str: """Strip leading/trailing whitespace and normalize internal whitespace""" # Remove everything between and , because this contains docker container ids s = re.sub(r"(.*?)", "", s, flags=re.DOTALL) # Replace all lines that have root in them because they tend to appear with times s = "\n".join(l for l in s.split("\n") if "root root" not in l) return "\n".join(line.rstrip() for line in s.strip().split("\n")) def assert_observations_match(expected_observations: list[str], messages: list[dict]) -> None: """Compare expected observations with actual observations from agent messages Args: expected_observations: List of expected observation strings messages: Agent conversation messages (list of message dicts with 'role' and 'content') """ # Extract actual observations from agent messages # User/exit messages (observations) are at indices 3, 5, 7, etc. actual_observations = [] for i in range(len(expected_observations)): user_message_index = 3 + (i * 2) assert messages[user_message_index]["role"] in ("user", "exit") actual_observations.append(messages[user_message_index]["content"]) assert len(actual_observations) == len(expected_observations), ( f"Expected {len(expected_observations)} observations, got {len(actual_observations)}" ) for i, (expected_observation, actual_observation) in enumerate(zip(expected_observations, actual_observations)): normalized_actual = normalize_outputs(actual_observation) normalized_expected = normalize_outputs(expected_observation) assert normalized_actual == normalized_expected, ( f"Step {i + 1} observation mismatch:\nExpected: {repr(normalized_expected)}\nActual: {repr(normalized_actual)}" ) @pytest.fixture def github_test_data(): """Load GitHub issue test fixtures""" return get_test_data("github_issue") @pytest.fixture def local_test_data(): """Load local test fixtures""" return get_test_data("local")