import litellm from minisweagent.models.litellm_model import LitellmModel, LitellmModelConfig from minisweagent.models.utils.actions_text import format_observation_messages, parse_regex_actions class LitellmTextbasedModelConfig(LitellmModelConfig): action_regex: str = r"```mswea_bash_command\s*\n(.*?)\n```" """Regex to extract the action from the LM's output.""" format_error_template: str = ( "Please always provide EXACTLY ONE action in triple backticks, found {{actions|length}} actions." ) """Template used when the LM's output is not in the expected format.""" class LitellmTextbasedModel(LitellmModel): def __init__(self, **kwargs): super().__init__(config_class=LitellmTextbasedModelConfig, **kwargs) def _query(self, messages: list[dict[str, str]], **kwargs): try: return litellm.completion( model=self.config.model_name, messages=messages, **(self.config.model_kwargs | kwargs) ) except litellm.exceptions.AuthenticationError as e: e.message += " You can permanently set your API key with `mini-extra config set KEY VALUE`." raise e def _parse_actions(self, response: dict) -> list[dict]: """Parse actions from the model response. Raises FormatError if not exactly one action.""" content = response.choices[0].message.content or "" return parse_regex_actions( content, action_regex=self.config.action_regex, format_error_template=self.config.format_error_template ) def format_observation_messages( self, message: dict, outputs: list[dict], template_vars: dict | None = None ) -> list[dict]: """Format execution outputs into observation messages.""" return format_observation_messages( outputs, observation_template=self.config.observation_template, template_vars=template_vars, multimodal_regex=self.config.multimodal_regex, )