Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-26 17:51:54 +01:00)

Model Server Async (#3386)

* need-verify
* fix some lib calls
* k
* tests
* k
* k
* k
* Address the comments
* fix comment

parent 056b671cd4
commit 6026536110
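
The commit converts the model server's embedding and reranking paths to asyncio: CloudEmbedding becomes an async context manager backed by an httpx.AsyncClient, provider SDK calls move to their async variants, and CPU-bound local models are pushed onto a thread pool. A rough sketch of the resulting calling pattern (hypothetical caller code written against the signatures in the diff below, not code from the commit):

    import asyncio

    from model_server.encoders import CloudEmbedding
    from shared_configs.enums import EmbeddingProvider
    from shared_configs.enums import EmbedTextType


    async def main() -> None:
        # "fake-key" and the model name are placeholders; the context manager
        # closes the underlying httpx.AsyncClient on exit.
        async with CloudEmbedding(
            api_key="fake-key", provider=EmbeddingProvider.OPENAI
        ) as model:
            embeddings = await model.embed(
                texts=["hello world"],
                model_name="text-embedding-ada-002",
                text_type=EmbedTextType.QUERY,
            )
            print(len(embeddings))


    asyncio.run(main())
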
@@ -1,4 +1,6 @@
+import asyncio
 import json
+from types import TracebackType
 from typing import cast
 from typing import Optional
 
@@ -6,11 +8,11 @@ import httpx
 import openai
 import vertexai  # type: ignore
 import voyageai  # type: ignore
-from cohere import Client as CohereClient
+from cohere import AsyncClient as CohereAsyncClient
 from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
-from litellm import embedding
+from litellm import aembedding
 from litellm.exceptions import RateLimitError
 from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
@@ -63,22 +65,31 @@ class CloudEmbedding:
         provider: EmbeddingProvider,
         api_url: str | None = None,
         api_version: str | None = None,
+        timeout: int = API_BASED_EMBEDDING_TIMEOUT,
     ) -> None:
         self.provider = provider
         self.api_key = api_key
         self.api_url = api_url
         self.api_version = api_version
+        self.timeout = timeout
+        self.http_client = httpx.AsyncClient(timeout=timeout)
+        self._closed = False
 
-    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
+    async def _embed_openai(
+        self, texts: list[str], model: str | None
+    ) -> list[Embedding]:
         if not model:
             model = DEFAULT_OPENAI_MODEL
 
-        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
+        # Use the OpenAI specific timeout for this one
+        client = openai.AsyncOpenAI(
+            api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
+        )
 
         final_embeddings: list[Embedding] = []
         try:
             for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
-                response = client.embeddings.create(input=text_batch, model=model)
+                response = await client.embeddings.create(input=text_batch, model=model)
                 final_embeddings.extend(
                     [embedding.embedding for embedding in response.data]
                 )
@@ -93,19 +104,19 @@ class CloudEmbedding:
             logger.error(error_string)
             raise RuntimeError(error_string)
 
-    def _embed_cohere(
+    async def _embed_cohere(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
             model = DEFAULT_COHERE_MODEL
 
-        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
+        client = CohereAsyncClient(api_key=self.api_key)
 
         final_embeddings: list[Embedding] = []
         for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
             # Does not use the same tokenizer as the Danswer API server but it's approximately the same
             # empirically it's only off by a very few tokens so it's not a big deal
-            response = client.embed(
+            response = await client.embed(
                 texts=text_batch,
                 model=model,
                 input_type=embedding_type,
@@ -114,26 +125,29 @@ class CloudEmbedding:
             final_embeddings.extend(cast(list[Embedding], response.embeddings))
         return final_embeddings
 
-    def _embed_voyage(
+    async def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
             model = DEFAULT_VOYAGE_MODEL
 
-        client = voyageai.Client(
+        client = voyageai.AsyncClient(
             api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
         )
 
-        response = client.embed(
-            texts,
+        response = await client.embed(
+            texts=texts,
             model=model,
             input_type=embedding_type,
             truncation=True,
         )
 
         return response.embeddings
 
-    def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
-        response = embedding(
+    async def _embed_azure(
+        self, texts: list[str], model: str | None
+    ) -> list[Embedding]:
+        response = await aembedding(
             model=model,
             input=texts,
             timeout=API_BASED_EMBEDDING_TIMEOUT,
@@ -142,10 +156,9 @@ class CloudEmbedding:
             api_version=self.api_version,
         )
         embeddings = [embedding["embedding"] for embedding in response.data]
-
         return embeddings
 
-    def _embed_vertex(
+    async def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
@@ -158,7 +171,7 @@ class CloudEmbedding:
         vertexai.init(project=project_id, credentials=credentials)
         client = TextEmbeddingModel.from_pretrained(model)
 
-        embeddings = client.get_embeddings(
+        embeddings = await client.get_embeddings_async(
             [
                 TextEmbeddingInput(
                     text,
@@ -166,11 +179,11 @@ class CloudEmbedding:
                 )
                 for text in texts
             ],
-            auto_truncate=True,  # Also this is default
+            auto_truncate=True,  # This is the default
         )
         return [embedding.values for embedding in embeddings]
 
-    def _embed_litellm_proxy(
+    async def _embed_litellm_proxy(
         self, texts: list[str], model_name: str | None
     ) -> list[Embedding]:
         if not model_name:
@@ -183,22 +196,20 @@ class CloudEmbedding:
             {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
         )
 
-        with httpx.Client() as client:
-            response = client.post(
-                self.api_url,
-                json={
-                    "model": model_name,
-                    "input": texts,
-                },
-                headers=headers,
-                timeout=API_BASED_EMBEDDING_TIMEOUT,
-            )
-            response.raise_for_status()
-            result = response.json()
-            return [embedding["embedding"] for embedding in result["data"]]
+        response = await self.http_client.post(
+            self.api_url,
+            json={
+                "model": model_name,
+                "input": texts,
+            },
+            headers=headers,
+        )
+        response.raise_for_status()
+        result = response.json()
+        return [embedding["embedding"] for embedding in result["data"]]
 
     @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
-    def embed(
+    async def embed(
         self,
         *,
         texts: list[str],
@@ -207,19 +218,19 @@ class CloudEmbedding:
         deployment_name: str | None = None,
     ) -> list[Embedding]:
         if self.provider == EmbeddingProvider.OPENAI:
-            return self._embed_openai(texts, model_name)
+            return await self._embed_openai(texts, model_name)
         elif self.provider == EmbeddingProvider.AZURE:
-            return self._embed_azure(texts, f"azure/{deployment_name}")
+            return await self._embed_azure(texts, f"azure/{deployment_name}")
         elif self.provider == EmbeddingProvider.LITELLM:
-            return self._embed_litellm_proxy(texts, model_name)
+            return await self._embed_litellm_proxy(texts, model_name)
 
         embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
         if self.provider == EmbeddingProvider.COHERE:
-            return self._embed_cohere(texts, model_name, embedding_type)
+            return await self._embed_cohere(texts, model_name, embedding_type)
         elif self.provider == EmbeddingProvider.VOYAGE:
-            return self._embed_voyage(texts, model_name, embedding_type)
+            return await self._embed_voyage(texts, model_name, embedding_type)
         elif self.provider == EmbeddingProvider.GOOGLE:
-            return self._embed_vertex(texts, model_name, embedding_type)
+            return await self._embed_vertex(texts, model_name, embedding_type)
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
 
@@ -233,6 +244,30 @@ class CloudEmbedding:
         logger.debug(f"Creating Embedding instance for provider: {provider}")
         return CloudEmbedding(api_key, provider, api_url, api_version)
 
+    async def aclose(self) -> None:
+        """Explicitly close the client."""
+        if not self._closed:
+            await self.http_client.aclose()
+            self._closed = True
+
+    async def __aenter__(self) -> "CloudEmbedding":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.aclose()
+
+    def __del__(self) -> None:
+        """Finalizer to warn about unclosed clients."""
+        if not self._closed:
+            logger.warning(
+                "CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
+            )
+
 
 def get_embedding_model(
     model_name: str,
@@ -242,9 +277,6 @@ def get_embedding_model(
 
     global _GLOBAL_MODELS_DICT  # A dictionary to store models
 
-    if _GLOBAL_MODELS_DICT is None:
-        _GLOBAL_MODELS_DICT = {}
-
     if model_name not in _GLOBAL_MODELS_DICT:
         logger.notice(f"Loading {model_name}")
         # Some model architectures that aren't built into the Transformers or Sentence
@@ -275,7 +307,7 @@ def get_local_reranking_model(
 
 
 @simple_log_function_time()
-def embed_text(
+async def embed_text(
     texts: list[str],
     text_type: EmbedTextType,
     model_name: str | None,
@@ -311,18 +343,18 @@ def embed_text(
                "Cloud models take an explicit text type instead."
            )
 
-        cloud_model = CloudEmbedding(
+        async with CloudEmbedding(
             api_key=api_key,
             provider=provider_type,
             api_url=api_url,
             api_version=api_version,
-        )
-        embeddings = cloud_model.embed(
-            texts=texts,
-            model_name=model_name,
-            deployment_name=deployment_name,
-            text_type=text_type,
-        )
+        ) as cloud_model:
+            embeddings = await cloud_model.embed(
+                texts=texts,
+                model_name=model_name,
+                deployment_name=deployment_name,
+                text_type=text_type,
+            )
 
         if any(embedding is None for embedding in embeddings):
             error_message = "Embeddings contain None values\n"
@@ -338,8 +370,12 @@ def embed_text(
         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
-        embeddings_vectors = local_model.encode(
-            prefixed_texts, normalize_embeddings=normalize_embeddings
+        # Run CPU-bound embedding in a thread pool
+        embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: local_model.encode(
+                prefixed_texts, normalize_embeddings=normalize_embeddings
+            ),
         )
         embeddings = [
             embedding if isinstance(embedding, list) else embedding.tolist()
@@ -357,27 +393,31 @@ def embed_text(
 
 
 @simple_log_function_time()
-def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
+async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
     cross_encoder = get_local_reranking_model(model_name)
-    return cross_encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
+    # Run CPU-bound reranking in a thread pool
+    return await asyncio.get_event_loop().run_in_executor(
+        None,
+        lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),  # type: ignore
+    )
 
 
-def cohere_rerank(
+async def cohere_rerank(
     query: str, docs: list[str], model_name: str, api_key: str
 ) -> list[float]:
-    cohere_client = CohereClient(api_key=api_key)
-    response = cohere_client.rerank(query=query, documents=docs, model=model_name)
+    cohere_client = CohereAsyncClient(api_key=api_key)
+    response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
    results = response.results
    sorted_results = sorted(results, key=lambda item: item.index)
    return [result.relevance_score for result in sorted_results]
 
 
-def litellm_rerank(
+async def litellm_rerank(
     query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
 ) -> list[float]:
     headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
-    with httpx.Client() as client:
-        response = client.post(
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
             api_url,
             json={
                 "model": model_name,
@@ -411,7 +451,7 @@ async def process_embed_request(
     else:
         prefix = None
 
-    embeddings = embed_text(
+    embeddings = await embed_text(
         texts=embed_request.texts,
         model_name=embed_request.model_name,
         deployment_name=embed_request.deployment_name,
@@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
 
     try:
         if rerank_request.provider_type is None:
-            sim_scores = local_rerank(
+            sim_scores = await local_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 model_name=rerank_request.model_name,
@@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
             if rerank_request.api_url is None:
                 raise ValueError("API URL is required for LiteLLM reranking.")
 
-            sim_scores = litellm_rerank(
+            sim_scores = await litellm_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 api_url=rerank_request.api_url,
@@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
         elif rerank_request.provider_type == RerankerProvider.COHERE:
             if rerank_request.api_key is None:
                 raise RuntimeError("Cohere Rerank Requires an API Key")
-            sim_scores = cohere_rerank(
+            sim_scores = await cohere_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 model_name=rerank_request.model_name,
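
For the local (non-cloud) paths, the model calls above stay synchronous internally and are handed to the default thread-pool executor so the event loop can keep serving requests. A minimal, standard-library-only sketch of that offloading pattern (the sleep stands in for a blocking model call):

    import asyncio
    import time


    def blocking_encode(texts: list[str]) -> list[list[float]]:
        # Stand-in for a blocking call such as SentenceTransformer.encode().
        time.sleep(1)
        return [[0.0, 0.0, 0.0] for _ in texts]


    async def encode_without_blocking(texts: list[str]) -> list[list[float]]:
        # Same pattern as embed_text / local_rerank above: run the blocking
        # work on the default executor and await the result.
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, lambda: blocking_encode(texts))


    async def main() -> None:
        # Two concurrent calls take roughly one second instead of two because
        # each blocking encode runs on its own worker thread.
        results = await asyncio.gather(
            encode_without_blocking(["a"]), encode_without_blocking(["b"])
        )
        print(len(results))


    asyncio.run(main())
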
@@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")
 
 
 @router.get("/health")
-def healthcheck() -> Response:
+async def healthcheck() -> Response:
     return Response(status_code=200)
 
 
 @router.get("/gpu-status")
-def gpu_status() -> dict[str, bool | str]:
+async def gpu_status() -> dict[str, bool | str]:
     if torch.cuda.is_available():
         return {"gpu_available": True, "type": "cuda"}
     elif torch.backends.mps.is_available():
@@ -1,3 +1,4 @@
+import asyncio
 import time
 from collections.abc import Callable
 from collections.abc import Generator
@@ -21,21 +22,39 @@ def simple_log_function_time(
     include_args: bool = False,
 ) -> Callable[[F], F]:
     def decorator(func: F) -> F:
-        @wraps(func)
-        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
-            start_time = time.time()
-            result = func(*args, **kwargs)
-            elapsed_time_str = str(time.time() - start_time)
-            log_name = func_name or func.__name__
-            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
-            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
-            if debug_only:
-                logger.debug(final_log)
-            else:
-                logger.notice(final_log)
-
-            return result
-
-        return cast(F, wrapped_func)
+        if asyncio.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
+                start_time = time.time()
+                result = await func(*args, **kwargs)
+                elapsed_time_str = str(time.time() - start_time)
+                log_name = func_name or func.__name__
+                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
+                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+                if debug_only:
+                    logger.debug(final_log)
+                else:
+                    logger.notice(final_log)
+                return result
+
+            return cast(F, wrapped_async_func)
+        else:
+
+            @wraps(func)
+            def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
+                start_time = time.time()
+                result = func(*args, **kwargs)
+                elapsed_time_str = str(time.time() - start_time)
+                log_name = func_name or func.__name__
+                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
+                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+                if debug_only:
+                    logger.debug(final_log)
+                else:
+                    logger.notice(final_log)
+                return result
+
+            return cast(F, wrapped_sync_func)
 
     return decorator
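
With the branch on asyncio.iscoroutinefunction, the same decorator now times both synchronous and asynchronous callables. A hedged usage sketch (the import path is assumed for illustration; the decorator is the simple_log_function_time shown above):

    import asyncio

    # Module path assumed from the repo layout, not confirmed by this diff.
    from model_server.utils import simple_log_function_time


    @simple_log_function_time()
    async def slow_async_op() -> str:
        # The async wrapper awaits the coroutine, so the logged duration covers
        # the awaited work rather than just coroutine creation.
        await asyncio.sleep(0.1)
        return "done"


    @simple_log_function_time()
    def fast_sync_op() -> str:
        return "done"


    print(asyncio.run(slow_async_op()), fast_sync_op())
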
@@ -1,30 +1,34 @@
black==23.3.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0
lxml==5.3.0
lxml_html_clean==0.2.2
mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
types-PyYAML==6.0.12.11
sentence-transformers==2.6.1
trafilatura==1.12.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-setuptools==68.0.0.3
types-Pillow==10.2.0.20240822
types-passlib==1.7.7.20240106
types-Pillow==10.2.0.20240822
types-psutil==5.9.5.17
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-pytz==2023.3.1.1
types-PyYAML==6.0.12.11
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-retry==0.9.9.3
types-setuptools==68.0.0.3
types-urllib3==1.26.25.11
trafilatura==1.12.2
lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
cohere==5.6.1
voyageai==0.2.3
backend/tests/unit/model_server/test_embedding.py (new file, 198 lines)
@@ -0,0 +1,198 @@
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Any
from typing import List
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from httpx import AsyncClient
from litellm.exceptions import RateLimitError

from model_server.encoders import CloudEmbedding
from model_server.encoders import embed_text
from model_server.encoders import local_rerank
from model_server.encoders import process_embed_request
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest


@pytest.fixture
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
    with patch("httpx.AsyncClient") as mock:
        client = AsyncMock(spec=AsyncClient)
        mock.return_value = client
        client.post = AsyncMock()
        async with client as c:
            yield c


@pytest.fixture
def sample_embeddings() -> List[List[float]]:
    return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]


@pytest.mark.asyncio
async def test_cloud_embedding_context_manager() -> None:
    async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
        assert not embedding._closed
    assert embedding._closed


@pytest.mark.asyncio
async def test_cloud_embedding_explicit_close() -> None:
    embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
    assert not embedding._closed
    await embedding.aclose()
    assert embedding._closed


@pytest.mark.asyncio
async def test_openai_embedding(
    mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
) -> None:
    with patch("openai.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client

        mock_response = MagicMock()
        mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
        mock_client.embeddings.create = AsyncMock(return_value=mock_response)

        embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
        result = await embedding._embed_openai(
            ["test1", "test2"], "text-embedding-ada-002"
        )

        assert result == sample_embeddings
        mock_client.embeddings.create.assert_called_once()


@pytest.mark.asyncio
async def test_embed_text_cloud_provider() -> None:
    with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
        mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
        mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])

        result = await embed_text(
            texts=["test1", "test2"],
            text_type=EmbedTextType.QUERY,
            model_name="fake-model",
            deployment_name=None,
            max_context_length=512,
            normalize_embeddings=True,
            api_key="fake-key",
            provider_type=EmbeddingProvider.OPENAI,
            prefix=None,
            api_url=None,
            api_version=None,
        )

        assert result == [[0.1, 0.2], [0.3, 0.4]]
        mock_embed.assert_called_once()


@pytest.mark.asyncio
async def test_embed_text_local_model() -> None:
    with patch("model_server.encoders.get_embedding_model") as mock_get_model:
        mock_model = MagicMock()
        mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
        mock_get_model.return_value = mock_model

        result = await embed_text(
            texts=["test1", "test2"],
            text_type=EmbedTextType.QUERY,
            model_name="fake-local-model",
            deployment_name=None,
            max_context_length=512,
            normalize_embeddings=True,
            api_key=None,
            provider_type=None,
            prefix=None,
            api_url=None,
            api_version=None,
        )

        assert result == [[0.1, 0.2], [0.3, 0.4]]
        mock_model.encode.assert_called_once()


@pytest.mark.asyncio
async def test_local_rerank() -> None:
    with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
        mock_model = MagicMock()
        mock_array = MagicMock()
        mock_array.tolist.return_value = [0.8, 0.6]
        mock_model.predict.return_value = mock_array
        mock_get_model.return_value = mock_model

        result = await local_rerank(
            query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
        )

        assert result == [0.8, 0.6]
        mock_model.predict.assert_called_once()


@pytest.mark.asyncio
async def test_rate_limit_handling() -> None:
    with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
        mock_embed.side_effect = RateLimitError(
            "Rate limit exceeded", llm_provider="openai", model="fake-model"
        )

        with pytest.raises(RateLimitError):
            await embed_text(
                texts=["test"],
                text_type=EmbedTextType.QUERY,
                model_name="fake-model",
                deployment_name=None,
                max_context_length=512,
                normalize_embeddings=True,
                api_key="fake-key",
                provider_type=EmbeddingProvider.OPENAI,
                prefix=None,
                api_url=None,
                api_version=None,
            )


@pytest.mark.asyncio
async def test_concurrent_embeddings() -> None:
    def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
        time.sleep(5)
        return [[0.1, 0.2, 0.3]]

    test_req = EmbedRequest(
        texts=["test"],
        model_name="'nomic-ai/nomic-embed-text-v1'",
        deployment_name=None,
        max_context_length=512,
        normalize_embeddings=True,
        api_key=None,
        provider_type=None,
        text_type=EmbedTextType.QUERY,
        manual_query_prefix=None,
        manual_passage_prefix=None,
        api_url=None,
        api_version=None,
    )

    with patch("model_server.encoders.get_embedding_model") as mock_get_model:
        mock_model = MagicMock()
        mock_model.encode = mock_encode
        mock_get_model.return_value = mock_model
        start_time = time.time()

        tasks = [process_embed_request(test_req) for _ in range(5)]
        await asyncio.gather(*tasks)

        end_time = time.time()

        # 5 * 5 seconds = 25 seconds, this test ensures that the embeddings are at least yielding the thread
        # However, the developer may still introduce unnecessary blocking above the mock and this test will
        # still pass as long as it's less than (7 - 5) / 5 seconds
        assert end_time - start_time < 7