mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-09 14:11:33 +02:00
Update auth for litellm proxy (#2316)
* update for auth * validated embedding model names * remove embedding provider * remove logs * add ability to delete search setting * add abiility to delete models + more streamlined API endpoints * remove upsert * minor typing fix * add connector utils
This commit is contained in:
@ -1,3 +1,5 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -13,6 +15,7 @@ from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
@ -89,6 +92,30 @@ def get_current_db_embedding_provider(
|
||||
return current_embedding_provider
|
||||
|
||||
|
||||
def delete_search_settings(db_session: Session, search_settings_id: int) -> None:
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
if current_settings.id == search_settings_id:
|
||||
raise ValueError("Cannot delete currently active search settings")
|
||||
|
||||
# First, delete associated index attempts
|
||||
index_attempts_query = delete(IndexAttempt).where(
|
||||
IndexAttempt.search_settings_id == search_settings_id
|
||||
)
|
||||
db_session.execute(index_attempts_query)
|
||||
|
||||
# Then, delete the search settings
|
||||
search_settings_query = delete(SearchSettings).where(
|
||||
and_(
|
||||
SearchSettings.id == search_settings_id,
|
||||
SearchSettings.status != IndexModelStatus.PRESENT,
|
||||
)
|
||||
)
|
||||
|
||||
db_session.execute(search_settings_query)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_current_search_settings(db_session: Session) -> SearchSettings:
|
||||
query = (
|
||||
select(SearchSettings)
|
||||
|
@ -95,6 +95,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
|
||||
|
||||
class EmbeddingModelDetail(BaseModel):
|
||||
id: int | None = None
|
||||
model_name: str
|
||||
normalize: bool
|
||||
query_prefix: str | None
|
||||
@ -112,6 +113,7 @@ class EmbeddingModelDetail(BaseModel):
|
||||
search_settings: "SearchSettings",
|
||||
) -> "EmbeddingModelDetail":
|
||||
return cls(
|
||||
id=search_settings.id,
|
||||
model_name=search_settings.model_name,
|
||||
normalize=search_settings.normalize,
|
||||
query_prefix=search_settings.query_prefix,
|
||||
|
@ -42,10 +42,10 @@ def test_embedding_configuration(
|
||||
api_key=test_llm_request.api_key,
|
||||
api_url=test_llm_request.api_url,
|
||||
provider_type=test_llm_request.provider_type,
|
||||
model_name=test_llm_request.model_name,
|
||||
normalize=False,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
model_name=None,
|
||||
)
|
||||
test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)
|
||||
|
||||
|
@ -8,10 +8,15 @@ if TYPE_CHECKING:
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
|
||||
|
||||
class SearchSettingsDeleteRequest(BaseModel):
|
||||
search_settings_id: int
|
||||
|
||||
|
||||
class TestEmbeddingRequest(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
api_key: str | None = None
|
||||
api_url: str | None = None
|
||||
model_name: str | None = None
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(BaseModel):
|
||||
|
@ -14,6 +14,7 @@ from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import create_search_settings
|
||||
from danswer.db.search_settings import delete_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_embedding_provider_from_provider_type
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
@ -23,6 +24,7 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.models import SearchSettingsCreationRequest
|
||||
from danswer.server.manage.embedding.models import SearchSettingsDeleteRequest
|
||||
from danswer.server.manage.models import FullModelVersionResponse
|
||||
from danswer.server.models import IdReturn
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -97,6 +99,7 @@ def set_new_search_settings(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=new_search_settings.index_name,
|
||||
)
|
||||
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=search_settings.model_dim,
|
||||
secondary_index_embedding_dim=new_search_settings.model_dim,
|
||||
@ -132,6 +135,21 @@ def cancel_new_embedding(
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/delete-search-settings")
|
||||
def delete_search_settings_endpoint(
|
||||
deletion_request: SearchSettingsDeleteRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
delete_search_settings(
|
||||
db_session=db_session,
|
||||
search_settings_id=deletion_request.search_settings_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/get-current-search-settings")
|
||||
def get_current_search_settings_endpoint(
|
||||
_: User | None = Depends(current_user),
|
||||
|
@ -237,15 +237,18 @@ def get_local_reranking_model(
|
||||
|
||||
|
||||
def embed_with_litellm_proxy(
|
||||
texts: list[str], api_url: str, model: str
|
||||
texts: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[Embedding]:
|
||||
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model,
|
||||
"model": model_name,
|
||||
"input": texts,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
@ -280,7 +283,12 @@ def embed_text(
|
||||
logger.error("API URL not provided for LiteLLM proxy")
|
||||
raise ValueError("API URL is required for LiteLLM proxy embedding.")
|
||||
try:
|
||||
return embed_with_litellm_proxy(texts, api_url, model_name or "")
|
||||
return embed_with_litellm_proxy(
|
||||
texts=texts,
|
||||
api_url=api_url,
|
||||
model_name=model_name or "",
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during LiteLLM proxy embedding: {str(e)}")
|
||||
raise
|
||||
|
@ -58,6 +58,7 @@ LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
|
||||
|
||||
# Fields which should only be set on new search setting
|
||||
PRESERVED_SEARCH_FIELDS = [
|
||||
"id",
|
||||
"provider_type",
|
||||
"api_key",
|
||||
"model_name",
|
||||
|
Reference in New Issue
Block a user