Double check all chat accessible dependencies (#3801)

* double check all chat accessible dependencies

* k

* k

* k

* k

* k

* k
This commit is contained in:
pablonyx 2025-01-28 09:38:32 -08:00 committed by GitHub
parent 2ad86aa9a6
commit 8d62b992ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 2 deletions

View File

@ -3,6 +3,8 @@ from sqlalchemy import or_
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from onyx.db.models import DocumentSet from onyx.db.models import DocumentSet
from onyx.db.models import LLMProvider as LLMProviderModel from onyx.db.models import LLMProvider as LLMProviderModel
@ -124,10 +126,29 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
def fetch_existing_llm_providers( def fetch_existing_llm_providers(
db_session: Session, db_session: Session,
) -> list[LLMProviderModel]:
stmt = select(LLMProviderModel)
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None, user: User | None = None,
) -> list[LLMProviderModel]: ) -> list[LLMProviderModel]:
if not user: if not user:
return list(db_session.scalars(select(LLMProviderModel)).all()) if AUTH_TYPE != AuthType.DISABLED:
# User is anonymous
return list(
db_session.scalars(
select(LLMProviderModel).where(
LLMProviderModel.is_public == True # noqa: E712
)
).all()
)
else:
# If auth is disabled, user has access to all providers
return fetch_existing_llm_providers(db_session)
stmt = select(LLMProviderModel).distinct() stmt = select(LLMProviderModel).distinct()
user_groups_select = select(User__UserGroup.user_group_id).where( user_groups_select = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id User__UserGroup.user_id == user.id

View File

@ -10,6 +10,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider from onyx.db.llm import update_default_provider
@ -195,5 +196,7 @@ def list_llm_provider_basics(
) -> list[LLMProviderDescriptor]: ) -> list[LLMProviderDescriptor]:
return [ return [
LLMProviderDescriptor.from_model(llm_provider_model) LLMProviderDescriptor.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session, user) for llm_provider_model in fetch_existing_llm_providers_for_user(
db_session, user
)
] ]