Source code for pyba.core.provider

from pyba.logger import get_logger
from pyba.utils.exceptions import (
    ServiceNotSelected,
    ServerLocationUndefined,
    UnsupportedModelUsed,
    InvalidModelSelected,
)
from pyba.utils.load_yaml import load_config

# Engine settings shared by every Provider instance: the "main_engine_configs"
# section of the "general" config, loaded once at import time. The Provider
# class indexes it per provider ("openai", "vertexai", "gemini") to read the
# "provider", "model" (default model) and "available_models" entries.
config = load_config("general")["main_engine_configs"]


class Provider:
    """
    Class to handle the provider instances.

    Resolves which LLM provider (OpenAI, VertexAI or Gemini) to use from the
    supplied credentials, then which model to run on that provider. When more
    than one credential is supplied the priority is OpenAI > VertexAI > Gemini,
    and the credentials of the providers that lose are cleared on the instance.
    """

    def __init__(
        self,
        openai_api_key: str | None = None,
        gemini_api_key: str | None = None,
        vertexai_project_id: str | None = None,
        vertexai_server_location: str | None = None,
        model_name: str | None = None,
    ):
        """
        Args:
            openai_api_key: API key for OpenAI models should you want to use that.
            gemini_api_key: API key for Gemini models.
            vertexai_project_id: VertexAI project to use instead of OpenAI.
            vertexai_server_location: VertexAI server location; required when
                ``vertexai_project_id`` is given.
            model_name: Model name to use. When omitted, the selected
                provider's default model from the config is used.

        Raises:
            ServiceNotSelected: No (non-empty) credential was supplied.
            ServerLocationUndefined: A VertexAI project ID was supplied
                without a server location.
            UnsupportedModelUsed: ``model_name`` is not offered by any provider.
            InvalidModelSelected: ``model_name`` is not offered by the selected
                provider.
        """
        self.provider: str | None = None
        self.model: str | None = None

        self.openai_api_key: str | None = openai_api_key
        self.vertexai_project_id: str | None = vertexai_project_id
        self.gemini_api_key: str | None = gemini_api_key
        self.location: str | None = vertexai_server_location
        self.model_name: str | None = model_name

        self.log = get_logger()

        # The provider must be resolved first: model validation is per provider.
        self.handle_keys()
        self.handle_model(provider=self.provider)

    def handle_keys(self):
        """
        Handles provider selection, defaults to openai when multiple providers conflict.

        Empty strings are treated the same as missing credentials. When more
        than one credential is present, OpenAI wins over VertexAI, which wins
        over Gemini; a warning is logged and the losing credentials are cleared.

        Raises:
            ServiceNotSelected: No non-empty credential was supplied.
            ServerLocationUndefined: A VertexAI project ID was supplied without
                a server location.
        """
        has_openai = bool(self.openai_api_key)
        has_vertexai = bool(self.vertexai_project_id)
        has_gemini = bool(self.gemini_api_key)

        if not (has_openai or has_vertexai or has_gemini):
            raise ServiceNotSelected()

        # Validated before conflict resolution, so an incomplete VertexAI setup
        # is reported even when another provider would end up being chosen.
        if has_vertexai and not self.location:
            raise ServerLocationUndefined(self.location)

        multiple_keys = sum((has_openai, has_vertexai, has_gemini)) > 1

        if has_openai:
            if multiple_keys:
                self.log.warning("Multiple LLM keys defined, defaulting to OpenAI")
            self.provider = config["openai"]["provider"]
            if multiple_keys:
                self.vertexai_project_id = None
                self.location = None
                self.gemini_api_key = None
        elif has_vertexai:
            # OpenAI is absent here, so Gemini is the only possible competitor.
            if multiple_keys:
                self.log.warning("Multiple LLM keys defined, defaulting to VertexAI")
            self.provider = config["vertexai"]["provider"]
            if multiple_keys:
                self.gemini_api_key = None
        else:
            # Only Gemini remains, so no conflict is possible in this branch.
            self.provider = config["gemini"]["provider"]

    def handle_model(self, provider: str):
        """
        Helper function that manages model selection based on the keys chosen.

        Note:
            The default model in the config is used if no model name was
            provided by the user. The valid model names per provider are also
            read from the config (``available_models``).

        Args:
            provider: The name of the provider in question; used as a key into
                the config and into ``self.valid_models`` (as
                ``"<provider>_models"``).

        Raises:
            UnsupportedModelUsed: ``model_name`` is not offered by any provider.
            InvalidModelSelected: ``model_name`` is offered, but not by
                ``provider``.
        """
        if not self.model_name:
            # Default model based on provider
            self.model = config[provider]["model"]
            return

        self.valid_models: dict = {
            "vertexai_models": config["vertexai"]["available_models"],
            "openai_models": config["openai"]["available_models"],
            "gemini_models": config["gemini"]["available_models"],
        }

        # Flattened once and reused both for the check and the error payload.
        all_models = [name for models in self.valid_models.values() for name in models]
        if self.model_name not in all_models:
            raise UnsupportedModelUsed(
                model_name=self.model_name,
                valid_model_names=all_models,
            )

        # valid_models is keyed "<provider>_models", not by the bare provider
        # name, so the same key must be used when reporting the valid list.
        provider_models = self.valid_models[f"{provider}_models"]
        if self.model_name not in provider_models:
            raise InvalidModelSelected(
                model_name=self.model_name,
                provider=provider,
                provider_valid_models=provider_models,
            )

        self.model = self.model_name