Source code for pyba.core.lib.mode.BFS

import asyncio
import uuid
from typing import List, Union

from playwright.async_api import async_playwright
from playwright_stealth import Stealth
from pydantic import BaseModel

from pyba.core.agent import PlannerAgent
from pyba.core.helpers.mem_dsl import MemDSL
from pyba.core.lib.action import perform_action
from pyba.core.lib.mode.base import BaseEngine
from pyba.core.scripts import LoginEngine
from pyba.database import Database
from pyba.utils.common import (  # serialize_action kept for db pushes
    initial_page_setup,
    serialize_action,
)
from pyba.utils.exceptions import UnknownSiteChosen
from pyba.utils.load_yaml import load_config
from pyba.utils.structure import PasswordManager

config = load_config("general")


class BFS(BaseEngine):
    """
    Methods for handling BFS exploratory searches.

    The `BaseEngine` initialises the provider and with that the playwright action
    and output agents. This is another entry point engine and can be directly
    imported by the user.

    The following params are defined:

    Args:
        openai_api_key: API key for OpenAI models should you want to use that
        vertexai_project_id: Create a VertexAI project to use that instead of OpenAI
        vertexai_server_location: VertexAI server location
        gemini_api_key: API key for Gemini-2.5-pro native support without VertexAI
        headless: Choose if you want to run in the headless mode or not
        handle_dependencies: Choose if you want to automatically install dependencies
            during runtime
        use_logger: Choose if you want to use the logger (that is enable logging of data)
        max_depth: The maximum depth to go into for each plan, where each level of
            depth corresponds to an action
        max_breadth: The number of plans to execute one by one in depth
        enable_tracing: Choose if you want to enable tracing. This will create a .zip
            file which you can use in traceviewer
        trace_save_directory: The directory where you want the .zip file to be saved
        database: An instance of the Database class which will define all database
            specific configs
        model_name: The model name which you want to run. The default is set to None
            (because it depends on the provider).
        low_memory: Forwarded unchanged to `BaseEngine`; defaults to the
            `minimize_memory` entry of the general config
        secrets: A password manager class which implements a resolve() method to give
            out a dictionary of secrets
        enable_screenshots: Forwarded unchanged to `BaseEngine`
        screenshot_directory: Forwarded unchanged to `BaseEngine`

    Find these default values at `pyba/config.yaml`.
    """

    def __init__(
        self,
        openai_api_key: str = None,
        vertexai_project_id: str = None,
        vertexai_server_location: str = None,
        gemini_api_key: str = None,
        headless: bool = config["main_engine_configs"]["headless_mode"],
        handle_dependencies: bool = config["main_engine_configs"]["handle_dependencies"],
        use_logger: bool = config["main_engine_configs"]["use_logger"],
        max_depth: int = config["main_engine_configs"]["max_depth"],
        max_breadth: int = config["main_engine_configs"]["max_breadth"],
        enable_tracing: bool = config["main_engine_configs"]["enable_tracing"],
        trace_save_directory: str = None,
        database: Database = None,
        model_name: str = None,
        low_memory: bool = config["main_engine_configs"]["minimize_memory"],
        secrets: PasswordManager = None,
        enable_screenshots: bool = False,
        screenshot_directory: str = None,
    ):
        self.mode = "BFS"

        # Passing the common setup to the BaseEngine
        # NOTE(review): `handle_dependencies` is accepted here but is not forwarded
        # to BaseEngine nor stored on the instance in this class - confirm intent.
        super().__init__(
            headless=headless,
            enable_tracing=enable_tracing,
            trace_save_directory=trace_save_directory,
            database=database,
            use_logger=use_logger,
            mode=self.mode,
            openai_api_key=openai_api_key,
            vertexai_project_id=vertexai_project_id,
            vertexai_server_location=vertexai_server_location,
            gemini_api_key=gemini_api_key,
            model_name=model_name,
            low_memory=low_memory,
            secrets=secrets,
            enable_screenshots=enable_screenshots,
            screenshot_directory=screenshot_directory,
        )

        # session_id is per-engine, not in BaseEngine, because BaseEngine is shared across modes
        self.session_id = uuid.uuid4().hex

        # The search limits are set *before* the planner is built so that the engine
        # handed to PlannerAgent is already fully configured.
        # NOTE(review): `max_breadth` is not read anywhere else in this class;
        # presumably the planner consumes it through the engine reference - confirm.
        self.max_depth = max_depth
        self.max_breadth = max_breadth
        self.planner_agent = PlannerAgent(engine=self)

    async def _run(
        self, task: str, extraction_format: BaseModel = None, context_id: str = None
    ) -> Union[str, None]:
        """
        helper run function for BFS

        Launches a dedicated browser for one plan and performs at most
        `self.max_depth` actions, returning as soon as an output is produced.

        Args:
            task: A singular task which needs to be performed
            extraction_format: The extraction format for the required goal
            context_id: A dynamically generated context-id for each browser window

        Returns:
            The first truthy output produced by `generate_output` or
            `retry_perform_action`, or `None` when the depth budget is used up
            without producing one.

        Since BFS generates multiple browser windows at runtime, each gets its own
        context ID to manage individual exponential retries and logging.
        """
        async with Stealth().use_async(async_playwright()) as p:
            browser = await p.chromium.launch(**self._launch_kwargs)
            context = None

            # The try/finally lives *inside* the playwright context manager: the
            # trace has to be saved and the browser shut down while the Playwright
            # driver is still running, and exactly once for every exit path.
            try:
                context = await self.get_trace_context(browser_instance=browser)
                page = await context.new_page()
                cleaned_dom = await initial_page_setup(page)
                mem = MemDSL()

                for _ in range(self.max_depth):
                    # A successful login consumes one level of depth
                    if await self.attempt_login(page):
                        # NOTE(review): unlike `extract_dom`, this helper is not
                        # given the local `page` - confirm it targets this window.
                        cleaned_dom = await self.successful_login_clean_and_get_dom()
                        continue

                    action = self.fetch_action(
                        cleaned_dom=cleaned_dom.to_dict(),
                        user_prompt=task,
                        action_history=mem.history,
                        extraction_format=extraction_format,
                        context_id=context_id,
                        action_status=True,
                        fail_reason=None,
                    )

                    output = await self.generate_output(
                        action=action, cleaned_dom=cleaned_dom, prompt=task
                    )
                    if output:
                        # Cleanup is handled by the `finally` below
                        return output

                    value, fail_reason = await perform_action(page, action)
                    line = mem.record(action, success=value is not None, fail_reason=fail_reason)
                    self.log.action(line)
                    await self._capture_screenshot(page)

                    # The attempted action is logged and the DOM refreshed regardless
                    # of whether the action succeeded.
                    if self.db_funcs:
                        self.db_funcs.push_to_bfs_episodic_memory(
                            session_id=self.session_id,
                            context_id=context_id,
                            action=serialize_action(action),
                            page_url=str(page.url),
                        )
                    cleaned_dom = await self.extract_dom(page)

                    if value is None:
                        # The action failed: hand over to the retry path
                        output = await self.retry_perform_action(
                            cleaned_dom=cleaned_dom.to_dict(),
                            prompt=task,
                            action_history=mem.history,
                            action_status=False,
                            fail_reason=fail_reason,
                            page=page,
                            mem=mem,
                        )
                        if output:
                            return output

                self.log.warning(
                    "The maximum depth for the current task has been reached, generating a new plan to achieve this task"
                )
                return None
            finally:
                # `context` stays None when trace-context creation itself failed; in
                # that case there is nothing to save and leaving the enclosing
                # `async with` stops Playwright.
                if context is not None:
                    await self.save_trace(context)
                    await self.shut_down(context, browser)

    async def run(
        self,
        prompt: str,
        automated_login_sites: List[str] = None,
        extraction_format: BaseModel = None,
    ) -> List:
        """
        The async run function

        Args:
            prompt: The prompt which needs to be converted to plans
            automated_login_sites: List of names for which sites to login automatically
            extraction_format: The extraction format for any extraction that needs to be done

        Returns:
            List holding the result of `_run` for every plan, in plan order

        Raises:
            UnknownSiteChosen: If a name in `automated_login_sites` is not an
                attribute of `LoginEngine`
        """
        if automated_login_sites is not None:
            assert isinstance(automated_login_sites, list), (
                "Make sure the automated_login_sites is a list!"
            )

            # Every name is resolved before the engine state is touched, so an
            # unknown site does not leave a half-updated list behind.
            login_classes = []
            for site_name in automated_login_sites:
                # Each engine is going to be a name like "instagram"
                if not hasattr(LoginEngine, site_name):
                    raise UnknownSiteChosen(LoginEngine.available_engines())
                login_classes.append(getattr(LoginEngine, site_name))
            self.automated_login_engine_classes.extend(login_classes)

        plan_list = self.planner_agent.generate(task=prompt)
        assert isinstance(plan_list, list), (
            f"Expected the plan to be a list, got {type(plan_list)} instead."
        )

        self.log.info(f"This is the plan for a BFS: {plan_list}")

        # Keeping this purely async is better for playwright: one task (and one
        # fresh context id) per plan, all awaited together.
        tasks = [
            asyncio.create_task(self._run(plan, extraction_format, uuid.uuid4().hex))
            for plan in plan_list
        ]

        results = await asyncio.gather(*tasks, return_exceptions=False)
        return results

    def sync_run(
        self,
        prompt: str,
        automated_login_sites: List[str] = None,
        extraction_format: BaseModel = None,
    ):
        """
        Synchronous endpoint for running BFS mode.

        Args:
            prompt: The prompt which needs to be converted to plans
            automated_login_sites: List of names for which sites to login automatically
            extraction_format: The extraction format for any extraction that needs to be done

        Returns:
            The list returned by `run` when it is non-empty, otherwise `None`
            (also `None` when interrupted from the keyboard)
        """
        try:
            output = asyncio.run(
                self.run(
                    prompt=prompt,
                    automated_login_sites=automated_login_sites,
                    extraction_format=extraction_format,
                )
            )
        except KeyboardInterrupt:
            # This is a forced shutdown, silently let it slip
            return None

        if output:
            return output
        return None