mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 09:41:52 +01:00
249 lines
8.4 KiB
Python
249 lines
8.4 KiB
Python
from sqlalchemy import delete
|
|
from sqlalchemy import or_
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.configs.app_configs import AUTH_TYPE
|
|
from onyx.configs.constants import AuthType
|
|
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
|
from onyx.db.models import DocumentSet
|
|
from onyx.db.models import LLMProvider as LLMProviderModel
|
|
from onyx.db.models import LLMProvider__UserGroup
|
|
from onyx.db.models import SearchSettings
|
|
from onyx.db.models import Tool as ToolModel
|
|
from onyx.db.models import User
|
|
from onyx.db.models import User__UserGroup
|
|
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
|
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
|
from onyx.server.manage.llm.models import FullLLMProvider
|
|
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
|
from shared_configs.enums import EmbeddingProvider
|
|
|
|
|
|
def update_group_llm_provider_relationships__no_commit(
    llm_provider_id: int,
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    """Replace the user-group links of an LLM provider without committing.

    Every existing ``LLMProvider__UserGroup`` row for ``llm_provider_id`` is
    deleted, then one row is added per *distinct* id in ``group_ids``. The
    caller owns the transaction and must commit (or roll back) the session.

    Args:
        llm_provider_id: Id of the LLM provider whose group links are replaced.
        group_ids: Ids of the user groups that should be linked. ``None`` or an
            empty list leaves the provider with no group links.
        db_session: Session on which the delete and inserts are issued.
    """
    # Drop all current links for this provider. "fetch" synchronization keeps
    # objects already loaded in the session consistent with the bulk delete.
    db_session.query(LLMProvider__UserGroup).filter(
        LLMProvider__UserGroup.llm_provider_id == llm_provider_id
    ).delete(synchronize_session="fetch")

    if not group_ids:
        return

    # De-duplicate while preserving order. A repeated id would otherwise add
    # two identical link rows for the same (provider, group) pair.
    unique_group_ids = list(dict.fromkeys(group_ids))
    db_session.add_all(
        [
            LLMProvider__UserGroup(
                llm_provider_id=llm_provider_id,
                user_group_id=group_id,
            )
            for group_id in unique_group_ids
        ]
    )
|
|
|
|
|
|
def upsert_cloud_embedding_provider(
    db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
    """Create or update the cloud embedding provider for ``provider.provider_type``.

    The first ``CloudEmbeddingProviderModel`` row with a matching
    ``provider_type`` is looked up. If found, every field of the request's
    ``model_dump()`` is copied onto it. Otherwise a new row is constructed
    from those fields and added to the session. The session is then
    committed and the row refreshed.

    Returns:
        The persisted row converted via ``CloudEmbeddingProvider.from_request``.
    """
    provider_row = (
        db_session.query(CloudEmbeddingProviderModel)
        .filter_by(provider_type=provider.provider_type)
        .first()
    )

    if not provider_row:
        # No row for this provider type yet: build one from the request.
        provider_row = CloudEmbeddingProviderModel(**provider.model_dump())
        db_session.add(provider_row)
    else:
        # Overwrite the stored row field-by-field with the request values.
        for field_name, field_value in provider.model_dump().items():
            setattr(provider_row, field_name, field_value)

    db_session.commit()
    db_session.refresh(provider_row)
    return CloudEmbeddingProvider.from_request(provider_row)
|
|
|
|
|
|
def upsert_llm_provider(
    llm_provider: LLMProviderUpsertRequest,
    db_session: Session,
) -> FullLLMProvider:
    """Create or update an LLM provider identified by its ``name``.

    Every configurable column is copied from the request onto the row. The
    provider's user-group links are replaced with ``llm_provider.groups``,
    and the session is committed.

    Args:
        llm_provider: Desired provider configuration.
        db_session: Session used for the lookup, writes and commit.

    Returns:
        A ``FullLLMProvider`` built from the row before the commit.
    """
    provider_row = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
    )
    if not provider_row:
        provider_row = LLMProviderModel(name=llm_provider.name)
        db_session.add(provider_row)

    # Mirror the request's settings onto the ORM row. Group membership is
    # handled separately through the association table below.
    provider_row.provider = llm_provider.provider
    provider_row.api_key = llm_provider.api_key
    provider_row.api_base = llm_provider.api_base
    provider_row.api_version = llm_provider.api_version
    provider_row.custom_config = llm_provider.custom_config
    provider_row.default_model_name = llm_provider.default_model_name
    provider_row.fast_default_model_name = llm_provider.fast_default_model_name
    provider_row.model_names = llm_provider.model_names
    provider_row.is_public = llm_provider.is_public
    provider_row.display_model_names = llm_provider.display_model_names
    provider_row.deployment_name = llm_provider.deployment_name

    # A freshly created row has no primary key until it is flushed, and the
    # association rows below need that id.
    if not provider_row.id:
        db_session.flush()

    update_group_llm_provider_relationships__no_commit(
        llm_provider_id=provider_row.id,
        group_ids=llm_provider.groups,
        db_session=db_session,
    )

    # NOTE(review): the response is built before commit(), presumably so it
    # reads in-memory state rather than attributes expired by the commit.
    result = FullLLMProvider.from_model(provider_row)
    db_session.commit()
    return result
|
|
|
|
|
|
def fetch_existing_embedding_providers(
    db_session: Session,
) -> list[CloudEmbeddingProviderModel]:
    """Return every stored cloud embedding provider row as a list."""
    all_providers_stmt = select(CloudEmbeddingProviderModel)
    rows = db_session.scalars(all_providers_stmt).all()
    return list(rows)
|
|
|
|
|
|
def fetch_existing_doc_sets(
    db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
    """Return the ``DocumentSet`` rows whose ``id`` is in ``doc_ids``.

    Ids with no matching row are silently absent from the result.
    """
    matching_stmt = select(DocumentSet).where(DocumentSet.id.in_(doc_ids))
    rows = db_session.scalars(matching_stmt).all()
    return list(rows)
|
|
|
|
|
|
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
    """Return the ``Tool`` rows whose ``id`` is in ``tool_ids``.

    Ids with no matching row are silently absent from the result.
    """
    matching_stmt = select(ToolModel).where(ToolModel.id.in_(tool_ids))
    rows = db_session.scalars(matching_stmt).all()
    return list(rows)
|
|
|
|
|
|
def fetch_existing_llm_providers(
    db_session: Session,
) -> list[LLMProviderModel]:
    """Return every LLM provider row, with no access filtering."""
    rows = db_session.scalars(select(LLMProviderModel)).all()
    return list(rows)
|
|
|
|
|
|
def fetch_existing_llm_providers_for_user(
    db_session: Session,
    user: User | None = None,
) -> list[LLMProviderModel]:
    """Return the LLM providers visible to ``user``.

    Visibility rules:
      * no user, ``AUTH_TYPE`` is ``AuthType.DISABLED`` -> every provider;
      * no user, any other auth type (anonymous caller) -> public providers only;
      * a user -> public providers plus providers linked to any user group
        the user belongs to, without duplicates.
    """
    if not user:
        if AUTH_TYPE == AuthType.DISABLED:
            # Auth is off, so there is no access control to apply.
            return fetch_existing_llm_providers(db_session)

        # Anonymous caller while auth is enabled: public providers only.
        public_only_stmt = select(LLMProviderModel).where(
            LLMProviderModel.is_public == True  # noqa: E712
        )
        return list(db_session.scalars(public_only_stmt).all())

    # Subquery: ids of the groups this user belongs to.
    member_group_ids = select(User__UserGroup.user_group_id).where(
        User__UserGroup.user_id == user.id
    )
    # Subquery: ids of the providers linked to any of those groups.
    group_provider_ids = select(LLMProvider__UserGroup.llm_provider_id).where(
        LLMProvider__UserGroup.user_group_id.in_(member_group_ids)  # type: ignore
    )
    visible_stmt = (
        select(LLMProviderModel)
        .distinct()
        .where(
            or_(
                LLMProviderModel.is_public,
                LLMProviderModel.id.in_(group_provider_ids),
            )
        )
    )
    return list(db_session.scalars(visible_stmt).all())
|
|
|
|
|
|
def fetch_embedding_provider(
    db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
    """Return the cloud embedding provider row for ``provider_type``, or ``None``."""
    lookup_stmt = select(CloudEmbeddingProviderModel).where(
        CloudEmbeddingProviderModel.provider_type == provider_type
    )
    return db_session.scalar(lookup_stmt)
|
|
|
|
|
|
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
    """Return the provider flagged ``is_default_provider``, or ``None`` if none is."""
    default_stmt = select(LLMProviderModel).where(
        LLMProviderModel.is_default_provider == True  # noqa: E712
    )
    default_row = db_session.scalar(default_stmt)
    return FullLLMProvider.from_model(default_row) if default_row else None
|
|
|
|
|
|
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
    """Return the LLM provider named ``provider_name``, or ``None`` if absent."""
    by_name_stmt = select(LLMProviderModel).where(
        LLMProviderModel.name == provider_name
    )
    named_row = db_session.scalar(by_name_stmt)
    return FullLLMProvider.from_model(named_row) if named_row else None
|
|
|
|
|
|
def remove_embedding_provider(
    db_session: Session, provider_type: EmbeddingProvider
) -> None:
    """Delete an embedding provider and the search settings that use it, then commit.

    All ``SearchSettings`` rows with this ``provider_type`` are deleted first,
    followed by the ``CloudEmbeddingProviderModel`` row(s) of that type.
    """
    # NOTE(review): search settings go first, presumably because they
    # reference the provider being removed — confirm against the schema.
    deletions = (
        delete(SearchSettings).where(SearchSettings.provider_type == provider_type),
        delete(CloudEmbeddingProviderModel).where(
            CloudEmbeddingProviderModel.provider_type == provider_type
        ),
    )
    for statement in deletions:
        db_session.execute(statement)

    db_session.commit()
|
|
|
|
|
|
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
    """Delete an LLM provider and its user-group links, then commit.

    The ``LLMProvider__UserGroup`` association rows are removed before the
    provider row itself.
    """
    deletions = (
        # Dependent association rows first...
        delete(LLMProvider__UserGroup).where(
            LLMProvider__UserGroup.llm_provider_id == provider_id
        ),
        # ...then the provider itself.
        delete(LLMProviderModel).where(LLMProviderModel.id == provider_id),
    )
    for statement in deletions:
        db_session.execute(statement)

    db_session.commit()
|
|
|
|
|
|
def update_default_provider(provider_id: int, db_session: Session) -> None:
    """Make the provider with ``provider_id`` the default one and commit.

    Any provider currently flagged as default has its flag cleared first.

    Raises:
        ValueError: If no LLM provider with ``provider_id`` exists.
    """
    target = db_session.scalar(
        select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
    )
    if not target:
        raise ValueError(f"LLM Provider with id {provider_id} does not exist")

    current_default = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )
    if current_default:
        # Clear the old flag (to NULL, not False) and flush it to the database
        # before setting the new one. Otherwise a unique constraint on the flag
        # would be violated.
        current_default.is_default_provider = None
        db_session.flush()

    target.is_default_provider = True
    db_session.commit()
|