import asyncio
import uuid
from typing import List, Union
from playwright.async_api import async_playwright
from playwright_stealth import Stealth
from pydantic import BaseModel
from pyba.core.agent import PlannerAgent
from pyba.core.lib.action import perform_action
from pyba.core.lib.mode.base import BaseEngine
from pyba.core.scripts import LoginEngine
from pyba.database import Database
from pyba.utils.common import ( # serialize_action kept for db pushes
initial_page_setup,
serialize_action,
)
from pyba.utils.exceptions import UnknownSiteChosen
from pyba.utils.load_yaml import load_config
from pyba.utils.structure import PasswordManager
# Engine defaults, loaded once at import time from the "general" configuration.
# The values under `main_engine_configs` supply the keyword defaults of the
# engine constructor defined in this module.
config = load_config("general")
class DFS(BaseEngine):
    """
    Methods for handling DFS exploratory searches. The `BaseEngine` initialises
    the provider and with that the playwright action and output agents.

    This is another entry point engine and can be directly imported by the user.

    The following params are defined:

    Args:
        openai_api_key: API key for OpenAI models should you want to use that
        vertexai_project_id: Create a VertexAI project to use that instead of OpenAI
        vertexai_server_location: VertexAI server location
        gemini_api_key: API key for Gemini-2.5-pro native support without VertexAI
        headless: Choose if you want to run in the headless mode or not
        handle_dependencies: Choose if you want to automatically install dependencies during runtime.
            NOTE(review): this class accepts the argument but neither stores it nor
            forwards it to `BaseEngine` - confirm whether that is intended.
        use_random: Forwarded unchanged to `BaseEngine`; its meaning is defined there
        use_logger: Choose if you want to use the logger (that is enable logging of data)
        max_depth: The maximum depth to go into for each plan, where each level of depth corresponds to an action
        max_breadth: The number of plans to execute one by one in depth
        enable_tracing: Choose if you want to enable tracing. This will create a .zip file which you can use in traceviewer
        trace_save_directory: The directory where you want the .zip file to be saved
        database: An instance of the Database class which will define all database specific configs
        model_name: The model name which you want to run. The default is set to None (because it depends on the provider).
        low_memory: Forwarded unchanged to `BaseEngine`; defaults to the `minimize_memory` config value
        secrets: A password manager class which implements a resolve() method to give out a dictionary of secrets
        enable_screenshots: Forwarded unchanged to `BaseEngine`; off by default
        screenshot_directory: Forwarded unchanged to `BaseEngine`; presumably where screenshots are written - confirm

    Find these default values at `pyba/config.yaml`.
    """

    def __init__(
        self,
        openai_api_key: str = None,
        vertexai_project_id: str = None,
        vertexai_server_location: str = None,
        gemini_api_key: str = None,
        headless: bool = config["main_engine_configs"]["headless_mode"],
        handle_dependencies: bool = config["main_engine_configs"]["handle_dependencies"],
        use_random: bool = config["main_engine_configs"]["use_random"],
        use_logger: bool = config["main_engine_configs"]["use_logger"],
        max_depth: int = config["main_engine_configs"]["max_depth"],
        # The breadth default used to be read from the `max_depth` key (copy-paste slip).
        # Prefer a dedicated `max_breadth` key and fall back to the previous value when
        # the configuration does not define one, so existing setups keep working.
        # Assumes the config section is a plain mapping (as produced by a YAML load).
        max_breadth: int = config["main_engine_configs"].get(
            "max_breadth", config["main_engine_configs"]["max_depth"]
        ),
        enable_tracing: bool = config["main_engine_configs"]["enable_tracing"],
        trace_save_directory: str = None,
        database: Database = None,
        model_name: str = None,
        low_memory: bool = config["main_engine_configs"]["minimize_memory"],
        secrets: PasswordManager = None,
        enable_screenshots: bool = False,
        screenshot_directory: str = None,
    ):
        """
        Configure the engine for DFS mode.

        The shared setup is delegated to `BaseEngine.__init__`; this constructor
        then adds the DFS specific state (session id, planner agent, depth and
        breadth limits). See the class docstring for the meaning of each argument.
        """
        self.mode = "DFS"

        # Passing the common setup to the BaseEngine
        super().__init__(
            headless=headless,
            enable_tracing=enable_tracing,
            trace_save_directory=trace_save_directory,
            database=database,
            use_random=use_random,
            use_logger=use_logger,
            mode=self.mode,
            openai_api_key=openai_api_key,
            vertexai_project_id=vertexai_project_id,
            vertexai_server_location=vertexai_server_location,
            gemini_api_key=gemini_api_key,
            model_name=model_name,
            low_memory=low_memory,
            secrets=secrets,
            enable_screenshots=enable_screenshots,
            screenshot_directory=screenshot_directory,
        )

        # session_id is per-engine, not in BaseEngine, because BaseEngine is shared across modes
        self.session_id = uuid.uuid4().hex

        self.planner_agent = PlannerAgent(engine=self)
        self.max_depth = max_depth
        self.max_breadth = max_breadth

        # Holds the previous plan so the planner agent knows what has been tried already
        self.old_plan = None

    async def run(
        self,
        prompt: str,
        automated_login_sites: List[str] = None,
        extraction_format: BaseModel = None,
    ) -> Union[str, None]:
        """
        Run pyba in DFS mode.

        Up to `max_breadth` plans are generated one after another; each plan is
        followed for at most `max_depth` iterations. The first truthy output
        produced ends the run.

        Args:
            prompt: The task assigned to DFS by the user
            automated_login_sites: Login site name for pre-written scripts to run
            extraction_format: A pydantic BaseModel which defines the extraction format for any data extraction

        Returns:
            The first truthy output produced, or None when every plan is
            exhausted without one.

        Raises:
            AssertionError: If `automated_login_sites` is given but is not a list
            UnknownSiteChosen: If a requested login site is not an attribute of `LoginEngine`

        The task is fed into the planner to get a plan which is then passed to the action models
        to fetch an actionable element.
        """
        if automated_login_sites is not None:
            # Raised explicitly (instead of via `assert`) so the check still runs under
            # `python -O`; the exception type and message are unchanged for callers.
            if not isinstance(automated_login_sites, list):
                raise AssertionError("Make sure the automated_login_sites is a list!")

            for engine in automated_login_sites:
                # Each engine is going to be a name like "instagram"
                if hasattr(LoginEngine, engine):
                    engine_class = getattr(LoginEngine, engine)
                    self.automated_login_engine_classes.append(engine_class)
                else:
                    raise UnknownSiteChosen(LoginEngine.available_engines())

        try:
            async with Stealth().use_async(async_playwright()) as p:
                self.browser = await p.chromium.launch(**self._launch_kwargs)
                self.context = await self.get_trace_context()
                self.page = await self.context.new_page()
                cleaned_dom = await initial_page_setup(self.page)

                # Breadth: one freshly generated plan per iteration
                for _plan_round in range(self.max_breadth):
                    plan = self.planner_agent.generate(task=prompt, old_plan=self.old_plan)
                    self.log.info(f"This is the plan for a DFS: {plan}")

                    # Depth: one attempted action (or login) per iteration
                    for _ in range(self.max_depth):
                        login_attempted_successfully = await self.attempt_login()
                        if login_attempted_successfully:
                            # A successful login uses up this depth step; refresh the DOM and move on
                            cleaned_dom = await self.successful_login_clean_and_get_dom()
                            continue

                        action = self.fetch_action(
                            cleaned_dom=cleaned_dom.to_dict(),
                            user_prompt=plan,
                            action_history=self.mem.history,
                            extraction_format=extraction_format,
                            action_status=True,
                            fail_reason=None,
                        )

                        output = await self.generate_output(
                            action=action, cleaned_dom=cleaned_dom, prompt=plan
                        )
                        if output:
                            # NOTE(review): the `finally` clause below calls save_trace() and
                            # shut_down() again after this return; presumably both tolerate
                            # being called twice - confirm in BaseEngine.
                            await self.save_trace()
                            await self.shut_down()
                            return output

                        value, fail_reason = await perform_action(self.page, action)

                        # A `None` value is what this loop treats as a failed action
                        line = self.mem.record(
                            action, success=value is not None, fail_reason=fail_reason
                        )
                        self.log.action(line)
                        await self._capture_screenshot()

                        if value is None:
                            if self.db_funcs:
                                self.db_funcs.push_to_episodic_memory(
                                    session_id=self.session_id,
                                    action=serialize_action(action),
                                    page_url=str(self.page.url),
                                    action_status=False,
                                    fail_reason=fail_reason,
                                )

                            # Re-read the page before retrying the failed action
                            cleaned_dom = await self.extract_dom()
                            output = await self.retry_perform_action(
                                cleaned_dom=cleaned_dom.to_dict(),
                                prompt=plan,
                                action_history=self.mem.history,
                                action_status=False,
                                fail_reason=fail_reason,
                            )
                            if output:
                                await self.save_trace()
                                await self.shut_down()
                                return output
                        else:
                            if self.db_funcs:
                                self.db_funcs.push_to_episodic_memory(
                                    session_id=self.session_id,
                                    action=serialize_action(action),
                                    page_url=str(self.page.url),
                                    action_status=True,
                                    fail_reason=None,
                                )

                            cleaned_dom = await self.extract_dom()

                    self.log.warning(
                        "The maximum depth for the current plan has been reached, generating a new plan"
                    )
                    # Remember this plan so the planner can take it into account next round
                    self.old_plan = plan
        finally:
            # Always runs: on normal exhaustion, on early return and on any exception
            await self.save_trace()
            await self.shut_down()

    def sync_run(
        self,
        prompt: str,
        automated_login_sites: List[str] = None,
        extraction_format: BaseModel = None,
    ) -> Union[str, None]:
        """
        Sync endpoint for running the above function

        Drives `run()` to completion with `asyncio.run`.

        Returns:
            The output of `run()` when it is truthy, otherwise None. None is also
            returned when the run is interrupted with Ctrl-C.
        """
        try:
            output = asyncio.run(
                self.run(
                    prompt=prompt,
                    automated_login_sites=automated_login_sites,
                    extraction_format=extraction_format,
                )
            )
        except KeyboardInterrupt:
            # This is a forced shutdown, silently let it slip
            return None

        # Falsy outputs are normalised to None, as before
        return output if output else None