From 8d62b992ef4391181b82c1fe9840be03bbf11fee Mon Sep 17 00:00:00 2001 From: pablonyx Date: Tue, 28 Jan 2025 09:38:32 -0800 Subject: [PATCH] Double check all chat accessible dependencies (#3801) * double check all chat accessible dependencies * k * k * k * k * k * k --- backend/onyx/db/llm.py | 23 ++++++++++++++++++++++- backend/onyx/server/manage/llm/api.py | 5 ++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/backend/onyx/db/llm.py b/backend/onyx/db/llm.py index b5eb23428..eff919295 100644 --- a/backend/onyx/db/llm.py +++ b/backend/onyx/db/llm.py @@ -3,6 +3,8 @@ from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.orm import Session +from onyx.configs.app_configs import AUTH_TYPE +from onyx.configs.constants import AuthType from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from onyx.db.models import DocumentSet from onyx.db.models import LLMProvider as LLMProviderModel @@ -124,10 +126,29 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM def fetch_existing_llm_providers( db_session: Session, +) -> list[LLMProviderModel]: + stmt = select(LLMProviderModel) + return list(db_session.scalars(stmt).all()) + + +def fetch_existing_llm_providers_for_user( + db_session: Session, user: User | None = None, ) -> list[LLMProviderModel]: if not user: - return list(db_session.scalars(select(LLMProviderModel)).all()) + if AUTH_TYPE != AuthType.DISABLED: + # User is anonymous + return list( + db_session.scalars( + select(LLMProviderModel).where( + LLMProviderModel.is_public == True # noqa: E712 + ) + ).all() + ) + else: + # If auth is disabled, user has access to all providers + return fetch_existing_llm_providers(db_session) + stmt = select(LLMProviderModel).distinct() user_groups_select = select(User__UserGroup.user_group_id).where( User__UserGroup.user_id == user.id diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index b5b52f590..ad0bc9742 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -10,6 +10,7 @@ from onyx.auth.users import current_admin_user from onyx.auth.users import current_chat_accesssible_user from onyx.db.engine import get_session from onyx.db.llm import fetch_existing_llm_providers +from onyx.db.llm import fetch_existing_llm_providers_for_user from onyx.db.llm import fetch_provider from onyx.db.llm import remove_llm_provider from onyx.db.llm import update_default_provider @@ -195,5 +196,7 @@ def list_llm_provider_basics( ) -> list[LLMProviderDescriptor]: return [ LLMProviderDescriptor.from_model(llm_provider_model) - for llm_provider_model in fetch_existing_llm_providers(db_session, user) + for llm_provider_model in fetch_existing_llm_providers_for_user( + db_session, user + ) ]