Source code for pyba.core.lib.mode.DFS

import asyncio
import uuid
from typing import List, Union

from playwright.async_api import async_playwright
from playwright_stealth import Stealth
from pydantic import BaseModel

from pyba.core.agent import PlannerAgent
from pyba.core.lib.action import perform_action
from pyba.core.lib.mode.base import BaseEngine
from pyba.core.scripts import LoginEngine
from pyba.database import Database
from pyba.utils.common import (  # serialize_action kept for db pushes
    initial_page_setup,
    serialize_action,
)
from pyba.utils.exceptions import UnknownSiteChosen
from pyba.utils.load_yaml import load_config
from pyba.utils.structure import PasswordManager

# Module-level "general" configuration, loaded once at import time. Its
# "main_engine_configs" section supplies the default values of the DFS
# constructor parameters below (headless mode, depth limits, tracing, ...).
config = load_config("general")


class DFS(BaseEngine):
    """
    Methods for handling DFS exploratory searches.

    The `BaseEngine` initialises the provider and with that the playwright action and
    output agents. This is another entry point engine and can be directly imported by
    the user.

    The search asks the planner agent for up to ``max_breadth`` plans, one after
    another, and explores each plan for up to ``max_depth`` actions before requesting
    a new plan (the previous plan is handed back to the planner as ``old_plan``).

    Args:
        openai_api_key: API key for OpenAI models should you want to use that
        vertexai_project_id: Create a VertexAI project to use that instead of OpenAI
        vertexai_server_location: VertexAI server location
        gemini_api_key: API key for Gemini-2.5-pro native support without VertexAI
        headless: Choose if you want to run in the headless mode or not
        handle_dependencies: Choose if you want to automatically install dependencies
            during runtime. NOTE: this engine currently accepts the flag but does not
            forward it to `BaseEngine` or use it anywhere, so it has no effect here.
        use_random: Forwarded to `BaseEngine` (see it for the exact semantics)
        use_logger: Choose if you want to use the logger (that is enable logging of data)
        max_depth: The maximum number of actions taken for each plan; every action is
            one level of depth
        max_breadth: The number of plans to generate and explore one after another.
            Defaults to the ``max_breadth`` config value, falling back to ``max_depth``
            when the config does not define it.
        enable_tracing: Choose if you want to enable tracing. This will create a .zip
            file which you can use in traceviewer
        trace_save_directory: The directory where you want the .zip file to be saved
        database: An instance of the Database class which will define all database
            specific configs
        model_name: The model name which you want to run. The default is set to None
            (because it depends on the provider).
        low_memory: Forwarded to `BaseEngine`; defaults to the ``minimize_memory``
            config value
        secrets: A password manager class which implements a resolve() method to give
            out a dictionary of secrets
        enable_screenshots: Forwarded to `BaseEngine`; `_capture_screenshot()` is
            called after every performed action
        screenshot_directory: Forwarded to `BaseEngine`

    Find these default values at `pyba/config.yaml`.
    """

    def __init__(
        self,
        openai_api_key: str = None,
        vertexai_project_id: str = None,
        vertexai_server_location: str = None,
        gemini_api_key: str = None,
        headless: bool = config["main_engine_configs"]["headless_mode"],
        handle_dependencies: bool = config["main_engine_configs"]["handle_dependencies"],
        use_random: bool = config["main_engine_configs"]["use_random"],
        use_logger: bool = config["main_engine_configs"]["use_logger"],
        max_depth: int = config["main_engine_configs"]["max_depth"],
        # The original default read the "max_depth" key (copy-paste slip). Prefer a
        # dedicated "max_breadth" key, but keep the old value when it is absent.
        # Assumes the loaded config section supports dict-style .get().
        max_breadth: int = config["main_engine_configs"].get(
            "max_breadth", config["main_engine_configs"]["max_depth"]
        ),
        enable_tracing: bool = config["main_engine_configs"]["enable_tracing"],
        trace_save_directory: str = None,
        database: Database = None,
        model_name: str = None,
        low_memory: bool = config["main_engine_configs"]["minimize_memory"],
        secrets: PasswordManager = None,
        enable_screenshots: bool = False,
        screenshot_directory: str = None,
    ):
        self.mode = "DFS"
        # NOTE(review): handle_dependencies is accepted for interface compatibility but
        # is neither forwarded to BaseEngine nor used in this class.

        # Passing the common setup to the BaseEngine
        super().__init__(
            headless=headless,
            enable_tracing=enable_tracing,
            trace_save_directory=trace_save_directory,
            database=database,
            use_random=use_random,
            use_logger=use_logger,
            mode=self.mode,
            openai_api_key=openai_api_key,
            vertexai_project_id=vertexai_project_id,
            vertexai_server_location=vertexai_server_location,
            gemini_api_key=gemini_api_key,
            model_name=model_name,
            low_memory=low_memory,
            secrets=secrets,
            enable_screenshots=enable_screenshots,
            screenshot_directory=screenshot_directory,
        )

        # session_id is per-engine, not in BaseEngine, because BaseEngine is shared across modes
        self.session_id = uuid.uuid4().hex
        self.planner_agent = PlannerAgent(engine=self)

        self.max_depth = max_depth
        self.max_breadth = max_breadth
        # The most recently exhausted plan, handed to the planner so it knows what has
        # already been tried. It is never reset, so it carries over between run() calls.
        self.old_plan = None

    def _register_login_engines(self, automated_login_sites: List[str]) -> None:
        """
        Queue the pre-written `LoginEngine` script class for each requested site.

        Args:
            automated_login_sites: Site names such as "instagram"; each must be an
                attribute of `LoginEngine`.

        Raises:
            AssertionError: if automated_login_sites is not a list.
            UnknownSiteChosen: on the first name `LoginEngine` does not provide; names
                before it have already been queued.
        """
        assert isinstance(automated_login_sites, list), (
            "Make sure the automated_login_sites is a list!"
        )
        for site in automated_login_sites:
            if not hasattr(LoginEngine, site):
                raise UnknownSiteChosen(LoginEngine.available_engines())
            self.automated_login_engine_classes.append(getattr(LoginEngine, site))

    def _push_episode(self, action, action_status: bool, fail_reason) -> None:
        """
        Record one performed action in episodic memory, if a database is configured.

        Args:
            action: The action returned by `fetch_action`; serialized before storing.
            action_status: Whether `perform_action` succeeded.
            fail_reason: The failure reason for a failed action, None otherwise.
        """
        if not self.db_funcs:
            return
        self.db_funcs.push_to_episodic_memory(
            session_id=self.session_id,
            action=serialize_action(action),
            page_url=str(self.page.url),
            action_status=action_status,
            fail_reason=fail_reason,
        )

    async def run(
        self,
        prompt: str,
        automated_login_sites: List[str] = None,
        extraction_format: BaseModel = None,
    ) -> Union[str, None]:
        """
        Run pyba in DFS mode.

        The task is fed into the planner to get a plan which is then passed to the
        action models to fetch an actionable element. Each plan is explored for up to
        `max_depth` actions; up to `max_breadth` plans are tried.

        Args:
            prompt: The task assigned to DFS by the user
            automated_login_sites: Login site name for pre-written scripts to run
            extraction_format: A pydantic BaseModel which defines the extraction format
                for any data extraction

        Returns:
            The first truthy output produced by `generate_output` or
            `retry_perform_action`, or None when every plan is exhausted without one.

        Raises:
            AssertionError: if automated_login_sites is given but is not a list.
            UnknownSiteChosen: if a requested login site has no pre-written script.
        """
        if automated_login_sites is not None:
            self._register_login_engines(automated_login_sites)

        async with Stealth().use_async(async_playwright()) as p:
            # Cleanup lives inside the ``async with`` so the trace is saved and the
            # browser shut down while playwright is still running. The single
            # ``finally`` also covers every early ``return``, so cleanup runs once.
            try:
                self.browser = await p.chromium.launch(**self._launch_kwargs)
                self.context = await self.get_trace_context()
                self.page = await self.context.new_page()
                cleaned_dom = await initial_page_setup(self.page)

                for _breadth in range(self.max_breadth):
                    plan = self.planner_agent.generate(task=prompt, old_plan=self.old_plan)
                    self.log.info(f"This is the plan for a DFS: {plan}")

                    for _depth in range(self.max_depth):
                        # A successful scripted login uses up one level of depth and
                        # refreshes the DOM instead of asking the model for an action.
                        if await self.attempt_login():
                            cleaned_dom = await self.successful_login_clean_and_get_dom()
                            continue

                        action = self.fetch_action(
                            cleaned_dom=cleaned_dom.to_dict(),
                            user_prompt=plan,
                            action_history=self.mem.history,
                            extraction_format=extraction_format,
                            action_status=True,
                            fail_reason=None,
                        )
                        output = await self.generate_output(
                            action=action, cleaned_dom=cleaned_dom, prompt=plan
                        )
                        if output:
                            return output

                        value, fail_reason = await perform_action(self.page, action)
                        succeeded = value is not None
                        line = self.mem.record(
                            action, success=succeeded, fail_reason=fail_reason
                        )
                        self.log.action(line)
                        await self._capture_screenshot()

                        self._push_episode(
                            action,
                            action_status=succeeded,
                            fail_reason=None if succeeded else fail_reason,
                        )
                        cleaned_dom = await self.extract_dom()

                        if not succeeded:
                            # Give the model one retry with the failure reason attached.
                            output = await self.retry_perform_action(
                                cleaned_dom=cleaned_dom.to_dict(),
                                prompt=plan,
                                action_history=self.mem.history,
                                action_status=False,
                                fail_reason=fail_reason,
                            )
                            if output:
                                return output

                    self.log.warning(
                        "The maximum depth for the current plan has been reached, generating a new plan"
                    )
                    self.old_plan = plan

                return None
            finally:
                await self.save_trace()
                await self.shut_down()

    def sync_run(
        self,
        prompt: str,
        automated_login_sites: List[str] = None,
        extraction_format: BaseModel = None,
    ) -> Union[str, None]:
        """
        Sync endpoint for running `run`.

        Args:
            prompt: The task assigned to DFS by the user
            automated_login_sites: Login site name for pre-written scripts to run
            extraction_format: A pydantic BaseModel which defines the extraction format

        Returns:
            The output of `run` when it is truthy, otherwise None. A KeyboardInterrupt
            during the run is swallowed and also yields None.
        """
        try:
            output = asyncio.run(
                self.run(
                    prompt=prompt,
                    automated_login_sites=automated_login_sites,
                    extraction_format=extraction_format,
                )
            )
        except KeyboardInterrupt:
            # This is a forced shutdown, silently let it slip
            return None
        return output or None