import json
from typing import Union, Any
from pyba.core.agent.base_agent import BaseAgent
from pyba.utils.load_yaml import load_config
from pyba.utils.prompts import planner_general_prompt_DFS, planner_general_prompt_BFS
from pyba.utils.structure import PlannerAgentOutputBFS, PlannerAgentOutputDFS
# Engine-level settings taken from the "general" config; PlannerAgent reads
# `max_breadth` from this mapping when it is constructed.
_general_config = load_config("general")
config = _general_config["main_engine_configs"]
del _general_config
class PlannerAgent(BaseAgent):
    """
    Planner agent for DFS and BFS modes under exploratory cases. This inherits
    from the Retry class as well (via `BaseAgent`) and supports all agents
    under LLM_factory.

    Args:
        `engine`: Engine to hold all arguments provided by the user

    Initialises `max_breadth`, the maximum number of plans to generate in BFS
    mode, from `config["max_breadth"]`.

    NOTE:
        `context_id` is not relevant here because this is a higher level class
    """

    def __init__(self, engine) -> None:
        """
        Initialise the base agent state and fetch the planner agent from the
        LLM factory.

        Args:
            `engine`: Engine holding all arguments provided by the user
        """
        super().__init__(engine=engine)  # Initialising the base params from BaseAgent
        # NOTE(review): `self.llm_factory` is not assigned in this class;
        # presumably BaseAgent.__init__ provides it — confirm.
        self.agent = self.llm_factory.get_planner_agent()
        self.max_breadth = config["max_breadth"]

    def _initialise_prompt(self, task: str, old_plan: Union[str, None] = None):
        """
        Build the planner prompt for the current mode.

        Args:
            `task`: Task given by the user
            `old_plan`: The previous plan in case of DFS mode (unused in BFS mode)

        Returns:
            The formatted BFS prompt when `self.mode == "BFS"`, otherwise the
            formatted DFS prompt.
        """
        # NOTE(review): `self.mode` is not assigned in this class; presumably
        # BaseAgent sets it from the engine — confirm.
        if self.mode == "BFS":
            return planner_general_prompt_BFS.format(
                task=task, max_plans=self.max_breadth
            )
        # Any mode other than "BFS" is treated as DFS.
        return planner_general_prompt_DFS.format(task=task, old_plan=old_plan)

    def _extract_plans(self, parsed: Any) -> Any:
        """
        Pull the plan payload out of a parsed planner response.

        `plans` is checked before `plan` for every provider so that all
        providers resolve the two keys in the same order.

        Args:
            `parsed`: Either a dict (decoded OpenAI JSON) or an object exposing
                attributes (VertexAI / Gemini parsed responses)

        Returns:
            The value stored under `plans` or `plan`, or None (after logging an
            error) when neither is present.
        """
        for key in ("plans", "plan"):
            if isinstance(parsed, dict):
                if key in parsed:
                    return parsed[key]
            elif hasattr(parsed, key):
                return getattr(parsed, key)
        # NOTE(review): `self.log` is not assigned in this class; presumably
        # provided by BaseAgent — confirm.
        self.log.error("Parsed object has neither 'plans' nor 'plan' attribute.")
        return None

    def _call_model(self, agent: Any, prompt: str) -> Any:
        """
        Call the configured LLM provider with the planner agent and extract
        the plan(s) from its response.

        Provider errors raised by the `handle_*_execution` helpers propagate
        unchanged. Failures while parsing the response are logged and turn
        into a None return, the same for every provider.

        Args:
            `agent`: The planner agent obtained from the LLM factory
            `prompt`: The fully formatted prompt string

        Returns:
            The `plans` / `plan` value from the parsed response, or None if the
            response could not be parsed or contained neither key.
        """
        provider = self.engine.provider

        if provider == "openai":
            response = self.handle_openai_execution(agent=agent, prompt=prompt)
            try:
                parsed_json = json.loads(response.choices[0].message.content)
            except (json.JSONDecodeError, TypeError, AttributeError, IndexError) as e:
                self.log.error(f"Unable to parse the output from OpenAI response: {e}")
                return None
            return self._extract_plans(parsed_json)

        if provider == "vertexai":  # VertexAI logic
            response = self.handle_vertexai_execution(agent=agent, prompt=prompt)
            try:
                # Prefer `output_parsed`, but fall back to `parsed` when the
                # former is missing OR present-but-None.
                parsed_object = getattr(response, "output_parsed", None)
                if parsed_object is None:
                    parsed_object = getattr(response, "parsed", None)
                if not parsed_object:
                    self.log.error("No parsed object found in VertexAI response.")
                    return None
                return self._extract_plans(parsed_object)
            except Exception as e:
                self.log.error(f"Unable to parse the output from VertexAI response: {e}")
                return None

        # Any other provider is handled as Gemini.
        response = self.handle_gemini_execution(agent=agent, prompt=prompt)
        try:
            # `agent` is subscripted here, so for Gemini it is a mapping whose
            # "response_format" entry validates JSON text (presumably a
            # pydantic model; pydantic's ValidationError subclasses ValueError).
            action = agent["response_format"].model_validate_json(response.text)
        except (ValueError, TypeError, AttributeError) as e:
            self.log.error(f"Unable to parse the output from Gemini response: {e}")
            return None
        return self._extract_plans(action)

    def generate(self, task: str, old_plan: Union[str, None] = None) -> Any:
        """
        Endpoint to generate the plan(s) depending on the set mode (the agent
        encodes the mode).

        Args:
            `task`: The task provided by the user
            `old_plan`: The previous plan if using DFS mode

        Returns:
            The `plans` / `plan` field extracted from the model's parsed
            response (not the whole parsed output model), or None if the
            response could not be parsed.

        Function:
            - Takes in the user prompt which serves as the task for the model to perform
            - Depending on DFS or BFS mode generates plan(s)
        """
        prompt = self._initialise_prompt(task=task, old_plan=old_plan)
        return self._call_model(agent=self.agent, prompt=prompt)