danswer/backend/onyx/tools/tool_constructor.py
Chris Weaver 420476ad92
Add basic passthrough auth (#3731)
* Add basic passthrough auth

* Add server-side validation

* Disallow for non-oauth

* Fix npm build
2025-01-20 23:39:23 -08:00

272 lines
10 KiB
Python

from typing import cast
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.configs.app_configs import AZURE_DALLE_API_BASE
from onyx.configs.app_configs import AZURE_DALLE_API_KEY
from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from onyx.configs.chat_configs import BING_API_KEY
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.built_in_tools import get_built_in_tool_by_id
from onyx.tools.models import DynamicSchemaInfo
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import compute_all_tool_tokens
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.headers import header_dict_to_header_list
from onyx.utils.logger import setup_logger
# Module-level logger for this file, created via onyx's ``setup_logger`` helper.
logger = setup_logger()
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
    """Helper function to get image generation LLM config based on available providers.

    Resolution order:
    1. If ``llm`` is an OpenAI model with an API key, reuse its credentials
       with the ``dall-e-3`` model.
    2. If ``llm`` is an Azure model and ``AZURE_DALLE_API_KEY`` is set, use the
       Azure DALL-E deployment settings from app config.
    3. Otherwise use the first stored LLM provider whose ``provider`` is
       ``"openai"``.

    Args:
        llm: the chat LLM; its ``config`` decides the first two branches. A
            falsy value skips both and goes straight to the DB lookup.
        db_session: session used by ``fetch_existing_llm_providers``.

    Returns:
        An ``LLMConfig`` for the image-generation model.

    Raises:
        ValueError: if the DB fallback finds no OpenAI provider, or the
            provider has no API key.
    """
    # Guard both provider checks the same way so a falsy ``llm`` falls through
    # to the DB lookup instead of raising AttributeError on ``.config``.
    llm_config = llm.config if llm else None

    # Reuse the chat model's own OpenAI credentials when it carries an API key.
    if (
        llm_config is not None
        and llm_config.api_key
        and llm_config.model_provider == "openai"
    ):
        return LLMConfig(
            model_provider=llm_config.model_provider,
            model_name="dall-e-3",
            temperature=GEN_AI_TEMPERATURE,
            api_key=llm_config.api_key,
            api_base=llm_config.api_base,
            api_version=llm_config.api_version,
        )

    # Azure chat models use the dedicated Azure DALL-E deployment, if configured.
    if (
        llm_config is not None
        and llm_config.model_provider == "azure"
        and AZURE_DALLE_API_KEY is not None
    ):
        return LLMConfig(
            model_provider="azure",
            model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
            temperature=GEN_AI_TEMPERATURE,
            api_key=AZURE_DALLE_API_KEY,
            api_base=AZURE_DALLE_API_BASE,
            api_version=AZURE_DALLE_API_VERSION,
        )

    # Fallback to checking for OpenAI provider in database (first match wins).
    llm_providers = fetch_existing_llm_providers(db_session)
    openai_provider = next(
        (
            llm_provider
            for llm_provider in llm_providers
            if llm_provider.provider == "openai"
        ),
        None,
    )
    if not openai_provider or not openai_provider.api_key:
        raise ValueError("Image generation tool requires an OpenAI API key")

    return LLMConfig(
        model_provider=openai_provider.provider,
        model_name="dall-e-3",
        temperature=GEN_AI_TEMPERATURE,
        api_key=openai_provider.api_key,
        api_base=openai_provider.api_base,
        api_version=openai_provider.api_version,
    )
class SearchToolConfig(BaseModel):
    """Settings used by ``construct_tools`` to build the built-in ``SearchTool``.

    Every field has a default, so ``SearchToolConfig()`` is a complete config.
    """

    # Answer/citation style; defaults to a plain ``CitationConfig``.
    answer_style_config: AnswerStyleConfig = Field(
        default_factory=lambda: AnswerStyleConfig(
            citation_config=CitationConfig(),
        ),
    )
    # Pruning settings handed to SearchTool; construct_tools also writes
    # ``tool_num_tokens`` and ``using_tool_message`` onto this object.
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig,
    )
    retrieval_options: RetrievalDetails = Field(
        default_factory=RetrievalDetails,
    )
    rerank_settings: RerankingDetails | None = Field(default=None)
    selected_sections: list[InferenceSection] | None = Field(default=None)

    # Context window around matched chunks, and whether to return whole docs.
    chunks_above: int = Field(default=0)
    chunks_below: int = Field(default=0)
    full_doc: bool = Field(default=False)

    latest_query_files: list[InMemoryChatFile] | None = Field(default=None)

    # Use with care, should only be used for OnyxBot in channels with multiple users
    bypass_acl: bool = Field(default=False)
def _default_internet_search_answer_style() -> AnswerStyleConfig:
    """Build the default answer style for internet search (``all_docs_useful=True``)."""
    citation_config = CitationConfig(all_docs_useful=True)
    return AnswerStyleConfig(citation_config=citation_config)


class InternetSearchToolConfig(BaseModel):
    """Settings used by ``construct_tools`` to build the ``InternetSearchTool``."""

    answer_style_config: AnswerStyleConfig = Field(
        default_factory=_default_internet_search_answer_style
    )
class ImageGenerationToolConfig(BaseModel):
    """Settings used by ``construct_tools`` to build the ``ImageGenerationTool``."""

    # Extra headers passed through as ``ImageGenerationTool(additional_headers=...)``.
    additional_headers: dict[str, str] | None = Field(default=None)
class CustomToolConfig(BaseModel):
    """Per-request context used by ``construct_tools`` for OpenAPI custom tools."""

    # Both IDs are forwarded to the custom tools via ``DynamicSchemaInfo``.
    chat_session_id: UUID | None = Field(default=None)
    message_id: int | None = Field(default=None)
    # Appended after each tool's own ``custom_headers``.
    additional_headers: dict[str, str] | None = Field(default=None)
def construct_tools(
    persona: Persona,
    prompt_config: PromptConfig,
    db_session: Session,
    user: User | None,
    llm: LLM,
    fast_llm: LLM,
    search_tool_config: SearchToolConfig | None = None,
    internet_search_tool_config: InternetSearchToolConfig | None = None,
    image_generation_tool_config: ImageGenerationToolConfig | None = None,
    custom_tool_config: CustomToolConfig | None = None,
) -> dict[int, list[Tool]]:
    """Constructs tools based on persona configuration and available APIs.

    Built-in tools (with an ``in_code_tool_id``) are resolved to their class
    and instantiated with the matching ``*_config`` argument. A default
    config is created when that argument is omitted. Tools that instead have
    an ``openapi_schema`` are expanded into custom tools.

    Args:
        persona: persona whose ``tools`` are constructed.
        prompt_config: passed to the search and internet-search tools.
        db_session: used to resolve built-in tools and provider configs.
        user: current user or None. If the user has OAuth accounts, the first
            account's access token is passed only to custom tools with
            ``passthrough_auth`` enabled.
        llm: main LLM. Used by the search tool, for image-generation
            config, and for token counting.
        fast_llm: secondary LLM handed to the search tool.
        search_tool_config: optional settings for ``SearchTool``.
        internet_search_tool_config: optional settings for ``InternetSearchTool``.
        image_generation_tool_config: optional settings for ``ImageGenerationTool``.
        custom_tool_config: optional context for OpenAPI custom tools.

    Returns:
        Mapping from DB tool id to the tools constructed for it.

    Raises:
        ValueError: if the internet search tool is requested without
            ``BING_API_KEY``, or image generation finds no OpenAI API key.

    Side effects:
        When a ``SearchToolConfig`` is in play, its
        ``document_pruning_config`` is updated in place. The update sets
        ``tool_num_tokens`` and ``using_tool_message``.
    """
    tool_dict: dict[int, list[Tool]] = {}

    # Get user's OAuth token if available. Only the first linked account is
    # used, and the token is only forwarded to passthrough-auth custom tools.
    user_oauth_token = None
    if user and user.oauth_accounts:
        user_oauth_token = user.oauth_accounts[0].access_token

    for db_tool_model in persona.tools:
        if db_tool_model.in_code_tool_id:
            tool_cls = get_built_in_tool_by_id(
                db_tool_model.in_code_tool_id, db_session
            )

            # Built-in tools are matched by class name rather than identity.
            # Handle Search Tool
            if tool_cls.__name__ == SearchTool.__name__:
                if not search_tool_config:
                    search_tool_config = SearchToolConfig()

                search_tool = SearchTool(
                    db_session=db_session,
                    user=user,
                    persona=persona,
                    retrieval_options=search_tool_config.retrieval_options,
                    prompt_config=prompt_config,
                    llm=llm,
                    fast_llm=fast_llm,
                    pruning_config=search_tool_config.document_pruning_config,
                    answer_style_config=search_tool_config.answer_style_config,
                    selected_sections=search_tool_config.selected_sections,
                    chunks_above=search_tool_config.chunks_above,
                    chunks_below=search_tool_config.chunks_below,
                    full_doc=search_tool_config.full_doc,
                    evaluation_type=(
                        LLMEvaluationType.BASIC
                        if persona.llm_relevance_filter
                        else LLMEvaluationType.SKIP
                    ),
                    rerank_settings=search_tool_config.rerank_settings,
                    bypass_acl=search_tool_config.bypass_acl,
                )
                tool_dict[db_tool_model.id] = [search_tool]

            # Handle Image Generation Tool
            elif tool_cls.__name__ == ImageGenerationTool.__name__:
                if not image_generation_tool_config:
                    image_generation_tool_config = ImageGenerationToolConfig()

                img_generation_llm_config = _get_image_generation_config(
                    llm, db_session
                )

                tool_dict[db_tool_model.id] = [
                    ImageGenerationTool(
                        api_key=cast(str, img_generation_llm_config.api_key),
                        api_base=img_generation_llm_config.api_base,
                        api_version=img_generation_llm_config.api_version,
                        additional_headers=image_generation_tool_config.additional_headers,
                        model=img_generation_llm_config.model_name,
                    )
                ]

            # Handle Internet Search Tool
            elif tool_cls.__name__ == InternetSearchTool.__name__:
                if not internet_search_tool_config:
                    internet_search_tool_config = InternetSearchToolConfig()

                if not BING_API_KEY:
                    raise ValueError(
                        "Internet search tool requires a Bing API key, please contact your Onyx admin to get it added!"
                    )
                tool_dict[db_tool_model.id] = [
                    InternetSearchTool(
                        api_key=BING_API_KEY,
                        answer_style_config=internet_search_tool_config.answer_style_config,
                        prompt_config=prompt_config,
                    )
                ]

            else:
                # Unrecognized built-in tool class: it is not constructed.
                # Log it so a misconfigured persona is diagnosable instead of
                # the tool silently disappearing.
                logger.warning(
                    "Skipping built-in tool id=%s (in_code_tool_id=%s): "
                    "unsupported tool class %s",
                    db_tool_model.id,
                    db_tool_model.in_code_tool_id,
                    tool_cls.__name__,
                )

        # Handle custom tools
        elif db_tool_model.openapi_schema:
            if not custom_tool_config:
                custom_tool_config = CustomToolConfig()

            tool_dict[db_tool_model.id] = cast(
                list[Tool],
                build_custom_tools_from_openapi_schema_and_headers(
                    db_tool_model.openapi_schema,
                    dynamic_schema_info=DynamicSchemaInfo(
                        chat_session_id=custom_tool_config.chat_session_id,
                        message_id=custom_tool_config.message_id,
                    ),
                    # Tool-level headers first, then per-request extras.
                    custom_headers=(db_tool_model.custom_headers or [])
                    + (
                        header_dict_to_header_list(
                            custom_tool_config.additional_headers or {}
                        )
                    ),
                    # The user's token is only exposed to tools that opt in.
                    user_oauth_token=(
                        user_oauth_token if db_tool_model.passthrough_auth else None
                    ),
                ),
            )

    # factor in tool definition size when pruning. When a search tool was
    # built above, this pruning config is the same object it received as
    # ``pruning_config``.
    if search_tool_config:
        tools: list[Tool] = [
            tool for tool_list in tool_dict.values() for tool in tool_list
        ]
        pruning_config = search_tool_config.document_pruning_config
        pruning_config.tool_num_tokens = compute_all_tool_tokens(
            tools,
            get_tokenizer(
                model_name=llm.config.model_name,
                provider_type=llm.config.model_provider,
            ),
        )
        pruning_config.using_tool_message = explicit_tool_calling_supported(
            llm.config.model_provider, llm.config.model_name
        )

    return tool_dict