import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Optional

import aioboto3  # type: ignore
import httpx
import openai
import vertexai  # type: ignore
import voyageai  # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account  # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from vertexai.language_models import TextEmbeddingInput  # type: ignore
from vertexai.language_models import TextEmbeddingModel  # type: ignore

from model_server.constants import DEFAULT_COHERE_MODEL
from model_server.constants import DEFAULT_OPENAI_MODEL
from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list


logger = setup_logger()

router = APIRouter(prefix="/encoder")

_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODEL: Optional["CrossEncoder"] = None

# If we are not in indexing-only mode, we don't want to retry for very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
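
# For reference: with INDEXING_ONLY set, the retry decorator makes up to 10
# attempts spaced 10s apart (~90s of waiting in the worst case); otherwise just
# 2 quick attempts 0.1s apart so user-facing requests fail fast.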

# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
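
# batch_list (imported from shared_configs.utils) is assumed to chunk a list into
# consecutive slices of at most the given size, e.g.
#   batch_list(["a", "b", "c"], 2) -> [["a", "b"], ["c"]]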

# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"


def is_authentication_error(error: Exception) -> bool:
    """Check if an exception is related to authentication issues.

    Args:
        error: The exception to check

    Returns:
        bool: True if the error appears to be authentication-related
    """
    error_str = str(error).lower()
    return (
        _AUTH_ERROR_401 in error_str
        or _AUTH_ERROR_UNAUTHORIZED in error_str
        or _AUTH_ERROR_INVALID_API_KEY in error_str
        or _AUTH_ERROR_PERMISSION in error_str
    )


def format_embedding_error(
    error: Exception,
    service_name: str,
    model: str | None,
    provider: EmbeddingProvider,
    status_code: int | None = None,
) -> str:
    """
    Format a standardized error string for embedding errors.
    """
    detail = f"Status {status_code}" if status_code else f"{type(error)}"

    return (
        f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
        f"Model: {model} "
        f"Provider: {provider} "
        f"Exception: {error}"
    )
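
# Example output (illustrative model name) for an HTTP 429 from OpenAI:
#   "HTTP error embedding text with OpenAI - Status 429: Model: text-embedding-3-small "
#   "Provider: EmbeddingProvider.OPENAI Exception: <original exception>"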


# Custom exception for authentication errors
class AuthenticationError(Exception):
    """Raised when authentication fails with a provider."""

    def __init__(self, provider: str, message: str = "API key is invalid or expired"):
        self.provider = provider
        self.message = message
        super().__init__(f"{provider} authentication failed: {message}")


class CloudEmbedding:
    def __init__(
        self,
        api_key: str,
        provider: EmbeddingProvider,
        api_url: str | None = None,
        api_version: str | None = None,
        timeout: int = API_BASED_EMBEDDING_TIMEOUT,
    ) -> None:
        self.provider = provider
        self.api_key = api_key
        self.api_url = api_url
        self.api_version = api_version
        self.timeout = timeout
        self.http_client = httpx.AsyncClient(timeout=timeout)
        self._closed = False

    async def _embed_openai(
        self, texts: list[str], model: str | None, reduced_dimension: int | None
    ) -> list[Embedding]:
        if not model:
            model = DEFAULT_OPENAI_MODEL

        # Use the OpenAI specific timeout for this one
        client = openai.AsyncOpenAI(
            api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
        )

        final_embeddings: list[Embedding] = []

        for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
            response = await client.embeddings.create(
                input=text_batch,
                model=model,
                dimensions=reduced_dimension or openai.NOT_GIVEN,
            )
            final_embeddings.extend(
                [embedding.embedding for embedding in response.data]
            )
        return final_embeddings

    async def _embed_cohere(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        if not model:
            model = DEFAULT_COHERE_MODEL

        client = CohereAsyncClient(api_key=self.api_key)

        final_embeddings: list[Embedding] = []
        for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
            # Does not use the same tokenizer as the Onyx API server, but it's
            # approximately the same; empirically it's only off by a few tokens,
            # so it's not a big deal
            response = await client.embed(
                texts=text_batch,
                model=model,
                input_type=embedding_type,
                truncate="END",
            )
            final_embeddings.extend(cast(list[Embedding], response.embeddings))
        return final_embeddings

    async def _embed_voyage(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        if not model:
            model = DEFAULT_VOYAGE_MODEL

        client = voyageai.AsyncClient(
            api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
        )

        response = await client.embed(
            texts=texts,
            model=model,
            input_type=embedding_type,
            truncation=True,
        )
        return response.embeddings

    async def _embed_azure(
        self, texts: list[str], model: str | None
    ) -> list[Embedding]:
        response = await aembedding(
            model=model,
            input=texts,
            timeout=API_BASED_EMBEDDING_TIMEOUT,
            api_key=self.api_key,
            api_base=self.api_url,
            api_version=self.api_version,
        )
        embeddings = [embedding["embedding"] for embedding in response.data]

        return embeddings

    async def _embed_vertex(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        if not model:
            model = DEFAULT_VERTEX_MODEL

        credentials = service_account.Credentials.from_service_account_info(
            json.loads(self.api_key)
        )
        project_id = json.loads(self.api_key)["project_id"]
        vertexai.init(project=project_id, credentials=credentials)
        client = TextEmbeddingModel.from_pretrained(model)

        inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]

        # Split into batches of at most VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE texts
        max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
        batches = [
            inputs[i : i + max_texts_per_batch]
            for i in range(0, len(inputs), max_texts_per_batch)
        ]

        # Dispatch all embedding calls asynchronously at once
        tasks = [
            client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
        ]

        # Wait for all tasks to complete in parallel
        results = await asyncio.gather(*tasks)

        return [embedding.values for batch in results for embedding in batch]

    async def _embed_litellm_proxy(
        self, texts: list[str], model_name: str | None
    ) -> list[Embedding]:
        if not model_name:
            raise ValueError("Model name is required for LiteLLM proxy embedding.")

        if not self.api_url:
            raise ValueError("API URL is required for LiteLLM proxy embedding.")

        headers = (
            {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
        )

        response = await self.http_client.post(
            self.api_url,
            json={
                "model": model_name,
                "input": texts,
            },
            headers=headers,
        )
        response.raise_for_status()
        result = response.json()
        return [embedding["embedding"] for embedding in result["data"]]
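
    # NOTE: `retry` below is the synchronous decorator from the `retry` package.
    # Whether failures raised while *awaiting* the coroutine (rather than while
    # creating it) are actually retried depends on the decorator's async support,
    # so treat the retry behavior here as best-effort.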

    @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
    async def embed(
        self,
        *,
        texts: list[str],
        text_type: EmbedTextType,
        model_name: str | None = None,
        deployment_name: str | None = None,
        reduced_dimension: int | None = None,
    ) -> list[Embedding]:
        try:
            if self.provider == EmbeddingProvider.OPENAI:
                return await self._embed_openai(texts, model_name, reduced_dimension)
            elif self.provider == EmbeddingProvider.AZURE:
                return await self._embed_azure(texts, f"azure/{deployment_name}")
            elif self.provider == EmbeddingProvider.LITELLM:
                return await self._embed_litellm_proxy(texts, model_name)

            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
            if self.provider == EmbeddingProvider.COHERE:
                return await self._embed_cohere(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.VOYAGE:
                return await self._embed_voyage(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.GOOGLE:
                return await self._embed_vertex(texts, model_name, embedding_type)
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        except openai.AuthenticationError:
            raise AuthenticationError(provider="OpenAI")
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise AuthenticationError(provider=str(self.provider))

            error_string = format_embedding_error(
                e,
                str(self.provider),
                model_name or deployment_name,
                self.provider,
                status_code=e.response.status_code,
            )
            logger.error(error_string)
            logger.debug(f"Exception texts: {texts}")

            raise RuntimeError(error_string)
        except Exception as e:
            if is_authentication_error(e):
                raise AuthenticationError(provider=str(self.provider))

            error_string = format_embedding_error(
                e, str(self.provider), model_name or deployment_name, self.provider
            )
            logger.error(error_string)
            logger.debug(f"Exception texts: {texts}")

            raise RuntimeError(error_string)

    @staticmethod
    def create(
        api_key: str,
        provider: EmbeddingProvider,
        api_url: str | None = None,
        api_version: str | None = None,
    ) -> "CloudEmbedding":
        logger.debug(f"Creating Embedding instance for provider: {provider}")
        return CloudEmbedding(api_key, provider, api_url, api_version)

    async def aclose(self) -> None:
        """Explicitly close the client."""
        if not self._closed:
            await self.http_client.aclose()
            self._closed = True

    async def __aenter__(self) -> "CloudEmbedding":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.aclose()

    def __del__(self) -> None:
        """Finalizer to warn about unclosed clients."""
        if not self._closed:
            logger.warning(
                "CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
            )
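
# Minimal usage sketch (mirrors embed_text below): use CloudEmbedding as an async
# context manager so the underlying httpx client is closed deterministically, e.g.
#
#     async with CloudEmbedding.create(api_key, EmbeddingProvider.OPENAI) as emb:
#         vectors = await emb.embed(texts=["hello"], text_type=EmbedTextType.QUERY)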


def get_embedding_model(
    model_name: str,
    max_context_length: int,
) -> "SentenceTransformer":
    from sentence_transformers import SentenceTransformer  # type: ignore

    global _GLOBAL_MODELS_DICT  # A dictionary to store models

    if model_name not in _GLOBAL_MODELS_DICT:
        logger.notice(f"Loading {model_name}")
        # Some model architectures that aren't built into Transformers or
        # Sentence-Transformers need their code downloaded in order to be loaded
        # locally. This does not mean data is sent to remote servers for inference;
        # however, the remote code can be fairly arbitrary, so only use trusted models.
        model = SentenceTransformer(
            model_name_or_path=model_name,
            trust_remote_code=True,
        )
        model.max_seq_length = max_context_length
        _GLOBAL_MODELS_DICT[model_name] = model
    elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length:
        _GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length

    return _GLOBAL_MODELS_DICT[model_name]


def get_local_reranking_model(
    model_name: str,
) -> CrossEncoder:
    global _RERANK_MODEL
    if _RERANK_MODEL is None:
        logger.notice(f"Loading {model_name}")
        model = CrossEncoder(model_name)
        _RERANK_MODEL = model
    # Note: only a single reranker is cached; once loaded, the same instance is
    # returned even if a different model_name is passed on a later call
    return _RERANK_MODEL


@simple_log_function_time()
async def embed_text(
    texts: list[str],
    text_type: EmbedTextType,
    model_name: str | None,
    deployment_name: str | None,
    max_context_length: int,
    normalize_embeddings: bool,
    api_key: str | None,
    provider_type: EmbeddingProvider | None,
    prefix: str | None,
    api_url: str | None,
    api_version: str | None,
    reduced_dimension: int | None,
    gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
    if not texts:
        logger.error("No texts provided for embedding")
        raise ValueError("No texts provided for embedding.")

    if not all(texts):
        logger.error("Empty strings provided for embedding")
        raise ValueError("Empty strings are not allowed for embedding.")

    start = time.monotonic()

    total_chars = 0
    for text in texts:
        total_chars += len(text)

    if provider_type is not None:
        logger.info(
            f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
        )

        if api_key is None:
            logger.error("API key not provided for cloud model")
            raise RuntimeError("API key not provided for cloud model")

        if prefix:
            logger.warning("Prefix provided for cloud model, which is not supported")
            raise ValueError(
                "Prefix string is not valid for cloud models. "
                "Cloud models take an explicit text type instead."
            )

        async with CloudEmbedding(
            api_key=api_key,
            provider=provider_type,
            api_url=api_url,
            api_version=api_version,
        ) as cloud_model:
            embeddings = await cloud_model.embed(
                texts=texts,
                model_name=model_name,
                deployment_name=deployment_name,
                text_type=text_type,
                reduced_dimension=reduced_dimension,
            )

        if any(embedding is None for embedding in embeddings):
            error_message = "Embeddings contain None values\n"
            error_message += "Corresponding texts:\n"
            error_message += "\n".join(texts)
            logger.error(error_message)
            raise ValueError(error_message)

        elapsed = time.monotonic() - start
        logger.info(
            f"event=embedding_provider "
            f"texts={len(texts)} "
            f"chars={total_chars} "
            f"provider={provider_type} "
            f"elapsed={elapsed:.2f}"
        )
    elif model_name is not None:
        logger.info(
            f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
        )

        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts

        local_model = get_embedding_model(
            model_name=model_name, max_context_length=max_context_length
        )
        # Run CPU-bound embedding in a thread pool
        embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
            None,
            lambda: local_model.encode(
                prefixed_texts, normalize_embeddings=normalize_embeddings
            ),
        )
        embeddings = [
            embedding if isinstance(embedding, list) else embedding.tolist()
            for embedding in embeddings_vectors
        ]

        elapsed = time.monotonic() - start
        logger.info(
            f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
            f"with local model {model_name} in {elapsed:.2f} seconds"
        )
        logger.info(
            f"event=embedding_model "
            f"texts={len(texts)} "
            f"chars={total_chars} "
            f"model={model_name} "
            f"gpu={gpu_type} "
            f"elapsed={elapsed:.2f}"
        )
    else:
        logger.error("Neither model name nor provider specified for embedding")
        raise ValueError(
            "Either model name or provider must be provided to run embeddings."
        )

    return embeddings


@simple_log_function_time()
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
    cross_encoder = get_local_reranking_model(model_name)
    # Run CPU-bound reranking in a thread pool
    return await asyncio.get_event_loop().run_in_executor(
        None,
        lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),  # type: ignore
    )


async def cohere_rerank_api(
    query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
    cohere_client = CohereAsyncClient(api_key=api_key)
    response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
    results = response.results
    # Cohere returns results ranked by relevance; sort by index so the scores
    # line up with the order of the input docs
    sorted_results = sorted(results, key=lambda item: item.index)
    return [result.relevance_score for result in sorted_results]


async def cohere_rerank_aws(
    query: str,
    docs: list[str],
    model_name: str,
    region_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
) -> list[float]:
    session = aioboto3.Session(
        aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
    )
    async with session.client(
        "bedrock-runtime", region_name=region_name
    ) as bedrock_client:
        body = json.dumps(
            {
                "query": query,
                "documents": docs,
                "api_version": 2,
            }
        )
        # Invoke the Bedrock model asynchronously
        response = await bedrock_client.invoke_model(
            modelId=model_name,
            accept="application/json",
            contentType="application/json",
            body=body,
        )

        # Read the response asynchronously
        response_body = json.loads(await response["body"].read())

        # Extract and sort the results
        results = response_body.get("results", [])
        sorted_results = sorted(results, key=lambda item: item["index"])

        return [result["relevance_score"] for result in sorted_results]


async def litellm_rerank(
    query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
    headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
    async with httpx.AsyncClient() as client:
        response = await client.post(
            api_url,
            json={
                "model": model_name,
                "query": query,
                "documents": docs,
            },
            headers=headers,
        )
        response.raise_for_status()
        result = response.json()
        # Expected response shape (inferred from this parsing):
        #   {"results": [{"index": 0, "relevance_score": 0.93}, ...]}
        return [
            item["relevance_score"]
            for item in sorted(result["results"], key=lambda x: x["index"])
        ]
@router.post("/bi-encoder-embed")
|
|
async def route_bi_encoder_embed(
|
|
request: Request,
|
|
embed_request: EmbedRequest,
|
|
) -> EmbedResponse:
|
|
return await process_embed_request(embed_request, request.app.state.gpu_type)
|
|
|
|
|
|


async def process_embed_request(
    embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
    if not embed_request.texts:
        raise HTTPException(status_code=400, detail="No texts to be embedded")

    if not all(embed_request.texts):
        raise ValueError("Empty strings are not allowed for embedding.")

    try:
        if embed_request.text_type == EmbedTextType.QUERY:
            prefix = embed_request.manual_query_prefix
        elif embed_request.text_type == EmbedTextType.PASSAGE:
            prefix = embed_request.manual_passage_prefix
        else:
            prefix = None

        embeddings = await embed_text(
            texts=embed_request.texts,
            model_name=embed_request.model_name,
            deployment_name=embed_request.deployment_name,
            max_context_length=embed_request.max_context_length,
            normalize_embeddings=embed_request.normalize_embeddings,
            api_key=embed_request.api_key,
            provider_type=embed_request.provider_type,
            text_type=embed_request.text_type,
            api_url=embed_request.api_url,
            api_version=embed_request.api_version,
            reduced_dimension=embed_request.reduced_dimension,
            prefix=prefix,
            gpu_type=gpu_type,
        )
        return EmbedResponse(embeddings=embeddings)
    except AuthenticationError as e:
        # Handle authentication errors consistently
        logger.error(f"Authentication error: {e.provider}")
        raise HTTPException(
            status_code=401,
            detail=f"Authentication failed: {e.message}",
        )
    except RateLimitError as e:
        raise HTTPException(
            status_code=429,
            detail=str(e),
        )
    except Exception as e:
        logger.exception(
            f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
        )
        raise HTTPException(
            status_code=500, detail=f"Error during embedding process: {e}"
        )
@router.post("/cross-encoder-scores")
|
|
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
|
"""Cross encoders can be purely black box from the app perspective"""
|
|
if INDEXING_ONLY:
|
|
raise RuntimeError("Indexing model server should not call intent endpoint")
|
|
|
|
if not rerank_request.documents or not rerank_request.query:
|
|
raise HTTPException(
|
|
status_code=400, detail="Missing documents or query for reranking"
|
|
)
|
|
if not all(rerank_request.documents):
|
|
raise ValueError("Empty documents cannot be reranked.")
|
|
|
|
try:
|
|
if rerank_request.provider_type is None:
|
|
sim_scores = await local_rerank(
|
|
query=rerank_request.query,
|
|
docs=rerank_request.documents,
|
|
model_name=rerank_request.model_name,
|
|
)
|
|
return RerankResponse(scores=sim_scores)
|
|
elif rerank_request.provider_type == RerankerProvider.LITELLM:
|
|
if rerank_request.api_url is None:
|
|
raise ValueError("API URL is required for LiteLLM reranking.")
|
|
|
|
sim_scores = await litellm_rerank(
|
|
query=rerank_request.query,
|
|
docs=rerank_request.documents,
|
|
api_url=rerank_request.api_url,
|
|
model_name=rerank_request.model_name,
|
|
api_key=rerank_request.api_key,
|
|
)
|
|
|
|
return RerankResponse(scores=sim_scores)
|
|
|
|
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
|
if rerank_request.api_key is None:
|
|
raise RuntimeError("Cohere Rerank Requires an API Key")
|
|
sim_scores = await cohere_rerank_api(
|
|
query=rerank_request.query,
|
|
docs=rerank_request.documents,
|
|
model_name=rerank_request.model_name,
|
|
api_key=rerank_request.api_key,
|
|
)
|
|
return RerankResponse(scores=sim_scores)
|
|
|
|
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
|
|
if rerank_request.api_key is None:
|
|
raise RuntimeError("Bedrock Rerank Requires an API Key")
|
|
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
|
|
rerank_request.api_key
|
|
)
|
|
sim_scores = await cohere_rerank_aws(
|
|
query=rerank_request.query,
|
|
docs=rerank_request.documents,
|
|
model_name=rerank_request.model_name,
|
|
region_name=aws_region,
|
|
aws_access_key_id=aws_access_key_id,
|
|
aws_secret_access_key=aws_secret_access_key,
|
|
)
|
|
return RerankResponse(scores=sim_scores)
|
|
else:
|
|
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Error during reranking process:\n{str(e)}")
|
|
raise HTTPException(
|
|
status_code=500, detail="Failed to run Cross-Encoder reranking"
|
|
)
|
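
# Example request against the rerank endpoint (illustrative; host/port and the
# model name are assumptions, not defined in this module):
#   curl -X POST http://localhost:9000/encoder/cross-encoder-scores \
#     -H "Content-Type: application/json" \
#     -d '{"query": "q", "documents": ["doc one", "doc two"], "model_name": "<reranker model>"}'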