Source code for pyba.core.agent.base_agent

import random
import time
from typing import Any, Callable, Dict, List, Literal, Optional

from pyba.core.agent.llm_factory import LLMFactory
from pyba.logger import get_logger


class BaseAgent:
    """
    The base class for all Agents, defining the methods they share.

    Besides prompt/argument helpers it provides exponential backoff and
    retry around the provider calls.

    Note: the backoff is implemented with `time.sleep`, so it blocks the
    calling context while waiting.

    Defines the following attributes:
        `base`: 2 (the exponential base used for the backoff)
        `base_timeout`: 1 (seconds, the delay for the first retry)
        `max_backoff_time`: 60 (seconds, cap applied before jitter is added)
        `engine`: The engine object passed to the constructor
        `llm_factory`: An `LLMFactory` built from the engine
        `log`: The logger for the agents
        `mode`: The engine's mode ("Normal", "DFS" or "BFS")
        `shared_depth_dictionary`: Maps a browser `context_id` to its current
            attempt number (1 means "no failed attempts so far")
    """

    def __init__(self, engine):
        self.base = 2
        self.base_timeout = 1
        self.max_backoff_time = 60

        self.engine = engine
        self.llm_factory = LLMFactory(engine=self.engine)
        self.log = get_logger()
        self.mode: Literal["Normal", "DFS", "BFS"] = self.engine.mode

        # Per-context attempt counters; the key is None when the caller does
        # not pass a context_id.
        self.shared_depth_dictionary: Dict[Optional[str], int] = {}

    def _initialise_prompt(self):
        """
        Function to initialise prompts.

        This function needs to be implemented for each agent.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass
        """
        raise NotImplementedError("Subclasses must implement _initialise_prompt")

    def _initialise_openai_arguments(
        self, system_instruction: str, prompt: str, model_name: str
    ) -> Dict[str, Any]:
        """
        Initialises the arguments for OpenAI agents

        Args:
            `system_instruction`: The system instruction for the agent
            `prompt`: The current prompt for the agent
            `model_name`: The OpenAI model name

        Returns:
            A dictionary with the keys `model` and `messages` which can be
            unpacked directly into the OpenAI completion call
        """
        return {
            "model": model_name,
            "messages": [
                {"role": "system", "content": system_instruction},
                {"role": "user", "content": prompt},
            ],
        }

    def _call_with_backoff(
        self, call: Callable[[], Any], provider: str, context_id: Optional[str]
    ) -> Any:
        """
        Runs `call` until it succeeds, sleeping with exponential backoff
        between failed attempts.

        Args:
            `call`: A zero-argument callable performing the provider request
            `provider`: The provider name used in the warning message
            `context_id`: The browser context whose attempt counter is used

        Returns:
            Whatever `call` returns on its first successful invocation
        """
        while True:
            try:
                response = call()
            except Exception:
                # Every exception is treated as a rate limit and retried
                # without an upper bound on the number of attempts.
                attempt_number = self.shared_depth_dictionary.get(context_id, 1)
                wait_time = self.calculate_next_time(attempt_number)
                self.log.warning(
                    f"Hit the rate limit for {provider}, retrying in {wait_time} seconds"
                )
                time.sleep(wait_time)  # wait_time is in seconds
                self.update_depth_ladder(unique_context_id=context_id)
            else:
                # Success: reset the counter so the next failure for this
                # context starts again from the base timeout.
                self.initialise_depth_ladder(unique_context_id=context_id)
                return response

    def handle_openai_execution(
        self, agent: Any, prompt: str, context_id: Optional[str] = None
    ):
        """
        Helper method to handle OpenAI execution

        Args:
            `agent`: The agent to use (action_agent or output_agent). It is
                indexed with the keys `client`, `system_instruction`, `model`
                and `response_format`
            `prompt`: The fully formatted prompt string
            `context_id`: A unique identifier for the current browser window

        The `context_id` is to help in differentiating between different
        browser windows during parallel execution for BFS mode.

        `context_id`=None => There is only one browser session.

        Returns:
            `response`: The raw response from the model. The exact required
            values are expected to be extracted within each agent
        """
        arguments = self._initialise_openai_arguments(
            system_instruction=agent["system_instruction"],
            prompt=prompt,
            model_name=agent["model"],
        )
        return self._call_with_backoff(
            lambda: agent["client"].chat.completions.parse(
                **arguments, response_format=agent["response_format"]
            ),
            provider="OpenAI",
            context_id=context_id,
        )

    def handle_vertexai_execution(
        self, agent: Any, prompt: str, context_id: Optional[str] = None
    ):
        """
        Helper method to handle VertexAI execution

        Args:
            `agent`: The agent to use (action_agent or output_agent). Its
                `send_message` method is called with the prompt
            `prompt`: The fully formatted prompt string
            `context_id`: A unique identifier for the current browser window

        The `context_id` is to help in differentiating between different
        browser windows during parallel execution for BFS mode.

        `context_id`=None => There is only one browser session.

        Returns:
            `response`: The raw response from the model. The exact required
            values are expected to be extracted within each agent
        """
        return self._call_with_backoff(
            lambda: agent.send_message(prompt),
            provider="VertexAI",
            context_id=context_id,
        )

    def handle_gemini_execution(
        self, agent: Any, prompt: str, context_id: Optional[str] = None
    ):
        """
        Helper method to handle gemini's execution

        Args:
            `agent`: The agent to use (action_agent or output_agent). It is
                indexed with the keys `client`, `system_instruction`, `model`
                and `response_format`
            `prompt`: The fully formatted prompt string
            `context_id`: A unique identifier for the current browser window

        The `context_id` is to help in differentiating between different
        browser windows during parallel execution for BFS mode.

        `context_id`=None => There is only one browser session.

        Returns:
            `response`: The raw response from the model. The exact required
            values are expected to be extracted within each agent
        """
        # Built once, outside the retry loop, since it does not change
        # between attempts.
        gemini_config = {
            "response_mime_type": "application/json",
            "response_json_schema": agent["response_format"].model_json_schema(),
            "system_instruction": agent["system_instruction"],
        }
        return self._call_with_backoff(
            lambda: agent["client"].models.generate_content(
                model=agent["model"],
                contents=prompt,
                config=gemini_config,
            ),
            provider="Gemini",
            context_id=context_id,
        )

    def calculate_next_time(self, attempt_number):
        """
        Function to calculate the next wait time in seconds

        The delay is `base_timeout * base ** (attempt_number - 1)`, capped at
        `max_backoff_time`. A random jitter between 0 and half of that delay
        is then added, so the returned value can be up to 1.5 times
        `max_backoff_time`.

        Args:
            `attempt_number`: The current attempt number (1 for the first retry)

        Returns:
            The number of seconds to wait before the next attempt
        """
        delay = self.base_timeout * (self.base ** (attempt_number - 1))
        delay = min(delay, self.max_backoff_time)
        jitter = random.uniform(0, delay / 2)
        return delay + jitter

    def initialise_depth_ladder(self, unique_context_id: Optional[str]):
        """
        Initialises (or resets) the depth-ladder entry for a browser session

        Args:
            `unique_context_id`: The context ID for the current browser session
        """
        self.shared_depth_dictionary[unique_context_id] = 1

    def update_depth_ladder(self, unique_context_id: Optional[str]):
        """
        Increments the depth-value for the given browser session

        A context that has never been initialised is treated as being at
        depth 1, matching the default used when the wait time is computed,
        so the very first failure of a new context does not raise `KeyError`.

        Args:
            `unique_context_id`: The context ID for the browser
        """
        current_depth = self.shared_depth_dictionary.get(unique_context_id, 1)
        self.shared_depth_dictionary[unique_context_id] = current_depth + 1