Source code for pyba.core.agent.planner_agent

import json
from typing import Union, Any

from pyba.core.agent.base_agent import BaseAgent
from pyba.utils.load_yaml import load_config
from pyba.utils.prompts import planner_general_prompt_DFS, planner_general_prompt_BFS
from pyba.utils.structure import PlannerAgentOutputBFS, PlannerAgentOutputDFS

config = load_config("general")["main_engine_configs"]


[docs] class PlannerAgent(BaseAgent): """ Planner agent for DFS and BFS exploration modes. Generates execution plans that are then carried out by the action agent. Args: engine: Engine instance holding all user-provided configuration. """ def __init__(self, engine) -> None: """ Initialises the right agent from the LLMFactory """ super().__init__(engine=engine) # Initialising the base params from BaseAgent self.agent = self.llm_factory.get_planner_agent() self.max_breadth = config["max_breadth"] def _initialise_prompt(self, task: str, old_plan: str = None): """ Formats the planner prompt based on the current mode (BFS or DFS). Args: task: The user's exploratory task. old_plan: The previous plan to diverge from (DFS mode only). """ if self.mode == "BFS": return planner_general_prompt_BFS.format(task=task, max_plans=self.max_breadth) else: return planner_general_prompt_DFS.format(task=task, old_plan=old_plan) def _call_model(self, agent: Any, prompt: str) -> Any: """ Generic method to call the correct LLM provider and parse the response. Args: agent: The agent to use (action_agent or output_agent) prompt: The fully formatted prompt string Returns: A plan string (DFS) or list of plan strings (BFS). """ if self.engine.provider == "openai": response = self.handle_openai_execution(agent=agent, prompt=prompt) parsed_json = json.loads(response.choices[0].message.content) if "plans" in list(parsed_json.keys()): return parsed_json["plans"] if "plan" in list(parsed_json.keys()): return parsed_json["plan"] self.log.error("Parsed object has neither 'plans' nor 'plan' attribute.") return None elif self.engine.provider == "vertexai": # VertexAI logic response = self.handle_vertexai_execution(agent=agent, prompt=prompt) try: parsed_object = getattr( response, "output_parsed", getattr(response, "parsed", None) ) if not parsed_object: self.log.error("No parsed object found in VertexAI response.") return None if hasattr(parsed_object, "plans"): return parsed_object.plans if hasattr(parsed_object, "plan"): return parsed_object.plan self.log.error("Parsed object has neither 'plans' nor 'plan' attribute.") return None except Exception as e: self.log.error(f"Unable to parse the output from VertexAI response: {e}") return None else: # Using gemini response = self.handle_gemini_execution(agent=agent, prompt=prompt) action = agent["response_format"].model_validate_json(response.text) if hasattr(action, "plan"): return action.plan elif hasattr(action, "plans"): return action.plans else: self.log.error("Parsed object has neither 'plans' nor 'plan' attribute.") return None
[docs] def generate( self, task: str, old_plan: str = None ) -> Union[PlannerAgentOutputBFS, PlannerAgentOutputDFS]: """ Generates exploration plan(s) based on the current mode. Args: task: The user's exploratory task. old_plan: The previous plan to diverge from (DFS mode only). Returns: A plan string (DFS) or list of plan strings (BFS). """ prompt = self._initialise_prompt(task=task, old_plan=old_plan) return self._call_model(agent=self.agent, prompt=prompt)