Model Server Async (#3386)

* need-verify

* fix some lib calls

* k

* tests

* k

* k

* k

* Address the comments

* fix comment
Yuhong Sun 2024-12-10 17:33:44 -08:00 committed by GitHub
parent 056b671cd4
commit 6026536110
5 changed files with 352 additions and 91 deletions

View File

@@ -1,4 +1,6 @@
+import asyncio
 import json
+from types import TracebackType
 from typing import cast
 from typing import Optional
@@ -6,11 +8,11 @@ import httpx
 import openai
 import vertexai  # type: ignore
 import voyageai  # type: ignore
-from cohere import Client as CohereClient
+from cohere import AsyncClient as CohereAsyncClient
 from fastapi import APIRouter
 from fastapi import HTTPException
 from google.oauth2 import service_account  # type: ignore
-from litellm import embedding
+from litellm import aembedding
 from litellm.exceptions import RateLimitError
 from retry import retry
 from sentence_transformers import CrossEncoder  # type: ignore
@@ -63,22 +65,31 @@ class CloudEmbedding:
         provider: EmbeddingProvider,
         api_url: str | None = None,
         api_version: str | None = None,
+        timeout: int = API_BASED_EMBEDDING_TIMEOUT,
     ) -> None:
         self.provider = provider
         self.api_key = api_key
         self.api_url = api_url
         self.api_version = api_version
+        self.timeout = timeout
+        self.http_client = httpx.AsyncClient(timeout=timeout)
+        self._closed = False

-    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
+    async def _embed_openai(
+        self, texts: list[str], model: str | None
+    ) -> list[Embedding]:
         if not model:
             model = DEFAULT_OPENAI_MODEL

-        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
+        # Use the OpenAI specific timeout for this one
+        client = openai.AsyncOpenAI(
+            api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
+        )

         final_embeddings: list[Embedding] = []
         try:
             for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
-                response = client.embeddings.create(input=text_batch, model=model)
+                response = await client.embeddings.create(input=text_batch, model=model)
                 final_embeddings.extend(
                     [embedding.embedding for embedding in response.data]
                 )
@@ -93,19 +104,19 @@ class CloudEmbedding:
             logger.error(error_string)
             raise RuntimeError(error_string)

-    def _embed_cohere(
+    async def _embed_cohere(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
             model = DEFAULT_COHERE_MODEL

-        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
+        client = CohereAsyncClient(api_key=self.api_key)

         final_embeddings: list[Embedding] = []
         for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
             # Does not use the same tokenizer as the Danswer API server but it's approximately the same
             # empirically it's only off by a very few tokens so it's not a big deal
-            response = client.embed(
+            response = await client.embed(
                 texts=text_batch,
                 model=model,
                 input_type=embedding_type,
@@ -114,26 +125,29 @@ class CloudEmbedding:
             final_embeddings.extend(cast(list[Embedding], response.embeddings))
         return final_embeddings

-    def _embed_voyage(
+    async def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
             model = DEFAULT_VOYAGE_MODEL

-        client = voyageai.Client(
+        client = voyageai.AsyncClient(
             api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
         )

-        response = client.embed(
-            texts,
+        response = await client.embed(
+            texts=texts,
             model=model,
             input_type=embedding_type,
             truncation=True,
         )
         return response.embeddings

-    def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
-        response = embedding(
+    async def _embed_azure(
+        self, texts: list[str], model: str | None
+    ) -> list[Embedding]:
+        response = await aembedding(
             model=model,
             input=texts,
             timeout=API_BASED_EMBEDDING_TIMEOUT,
@@ -142,10 +156,9 @@ class CloudEmbedding:
             api_version=self.api_version,
         )
         embeddings = [embedding["embedding"] for embedding in response.data]
         return embeddings

-    def _embed_vertex(
+    async def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
     ) -> list[Embedding]:
         if not model:
@@ -158,7 +171,7 @@ class CloudEmbedding:
         vertexai.init(project=project_id, credentials=credentials)
         client = TextEmbeddingModel.from_pretrained(model)

-        embeddings = client.get_embeddings(
+        embeddings = await client.get_embeddings_async(
             [
                 TextEmbeddingInput(
                     text,
@@ -166,11 +179,11 @@ class CloudEmbedding:
                 )
                 for text in texts
             ],
-            auto_truncate=True,  # Also this is default
+            auto_truncate=True,  # This is the default
         )
         return [embedding.values for embedding in embeddings]

-    def _embed_litellm_proxy(
+    async def _embed_litellm_proxy(
         self, texts: list[str], model_name: str | None
     ) -> list[Embedding]:
         if not model_name:
@@ -183,22 +196,20 @@ class CloudEmbedding:
             {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
         )

-        with httpx.Client() as client:
-            response = client.post(
-                self.api_url,
-                json={
-                    "model": model_name,
-                    "input": texts,
-                },
-                headers=headers,
-                timeout=API_BASED_EMBEDDING_TIMEOUT,
-            )
-            response.raise_for_status()
-            result = response.json()
-            return [embedding["embedding"] for embedding in result["data"]]
+        response = await self.http_client.post(
+            self.api_url,
+            json={
+                "model": model_name,
+                "input": texts,
+            },
+            headers=headers,
+        )
+        response.raise_for_status()
+        result = response.json()
+        return [embedding["embedding"] for embedding in result["data"]]

     @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
-    def embed(
+    async def embed(
         self,
         *,
         texts: list[str],
@@ -207,19 +218,19 @@ class CloudEmbedding:
         deployment_name: str | None = None,
     ) -> list[Embedding]:
         if self.provider == EmbeddingProvider.OPENAI:
-            return self._embed_openai(texts, model_name)
+            return await self._embed_openai(texts, model_name)
         elif self.provider == EmbeddingProvider.AZURE:
-            return self._embed_azure(texts, f"azure/{deployment_name}")
+            return await self._embed_azure(texts, f"azure/{deployment_name}")
         elif self.provider == EmbeddingProvider.LITELLM:
-            return self._embed_litellm_proxy(texts, model_name)
+            return await self._embed_litellm_proxy(texts, model_name)

         embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
         if self.provider == EmbeddingProvider.COHERE:
-            return self._embed_cohere(texts, model_name, embedding_type)
+            return await self._embed_cohere(texts, model_name, embedding_type)
         elif self.provider == EmbeddingProvider.VOYAGE:
-            return self._embed_voyage(texts, model_name, embedding_type)
+            return await self._embed_voyage(texts, model_name, embedding_type)
         elif self.provider == EmbeddingProvider.GOOGLE:
-            return self._embed_vertex(texts, model_name, embedding_type)
+            return await self._embed_vertex(texts, model_name, embedding_type)
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
@@ -233,6 +244,30 @@ class CloudEmbedding:
         logger.debug(f"Creating Embedding instance for provider: {provider}")
         return CloudEmbedding(api_key, provider, api_url, api_version)

+    async def aclose(self) -> None:
+        """Explicitly close the client."""
+        if not self._closed:
+            await self.http_client.aclose()
+            self._closed = True
+
+    async def __aenter__(self) -> "CloudEmbedding":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        await self.aclose()
+
+    def __del__(self) -> None:
+        """Finalizer to warn about unclosed clients."""
+        if not self._closed:
+            logger.warning(
+                "CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
+            )

 def get_embedding_model(
     model_name: str,
@@ -242,9 +277,6 @@ def get_embedding_model(
     global _GLOBAL_MODELS_DICT  # A dictionary to store models

-    if _GLOBAL_MODELS_DICT is None:
-        _GLOBAL_MODELS_DICT = {}
-
     if model_name not in _GLOBAL_MODELS_DICT:
         logger.notice(f"Loading {model_name}")
         # Some model architectures that aren't built into the Transformers or Sentence
@@ -275,7 +307,7 @@ def get_local_reranking_model(

 @simple_log_function_time()
-def embed_text(
+async def embed_text(
     texts: list[str],
     text_type: EmbedTextType,
     model_name: str | None,
@@ -311,18 +343,18 @@ def embed_text(
             "Cloud models take an explicit text type instead."
         )

-        cloud_model = CloudEmbedding(
+        async with CloudEmbedding(
             api_key=api_key,
             provider=provider_type,
             api_url=api_url,
             api_version=api_version,
-        )
-        embeddings = cloud_model.embed(
-            texts=texts,
-            model_name=model_name,
-            deployment_name=deployment_name,
-            text_type=text_type,
-        )
+        ) as cloud_model:
+            embeddings = await cloud_model.embed(
+                texts=texts,
+                model_name=model_name,
+                deployment_name=deployment_name,
+                text_type=text_type,
+            )

         if any(embedding is None for embedding in embeddings):
             error_message = "Embeddings contain None values\n"
@@ -338,8 +370,12 @@ def embed_text(
         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
-        embeddings_vectors = local_model.encode(
-            prefixed_texts, normalize_embeddings=normalize_embeddings
+        # Run CPU-bound embedding in a thread pool
+        embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
+            None,
+            lambda: local_model.encode(
+                prefixed_texts, normalize_embeddings=normalize_embeddings
+            ),
         )
         embeddings = [
             embedding if isinstance(embedding, list) else embedding.tolist()
@@ -357,27 +393,31 @@ def embed_text(

 @simple_log_function_time()
-def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
+async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
     cross_encoder = get_local_reranking_model(model_name)
-    return cross_encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
+    # Run CPU-bound reranking in a thread pool
+    return await asyncio.get_event_loop().run_in_executor(
+        None,
+        lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),  # type: ignore
+    )

-def cohere_rerank(
+async def cohere_rerank(
     query: str, docs: list[str], model_name: str, api_key: str
 ) -> list[float]:
-    cohere_client = CohereClient(api_key=api_key)
-    response = cohere_client.rerank(query=query, documents=docs, model=model_name)
+    cohere_client = CohereAsyncClient(api_key=api_key)
+    response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
     results = response.results
     sorted_results = sorted(results, key=lambda item: item.index)
     return [result.relevance_score for result in sorted_results]

-def litellm_rerank(
+async def litellm_rerank(
     query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
 ) -> list[float]:
     headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
-    with httpx.Client() as client:
-        response = client.post(
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
             api_url,
             json={
                 "model": model_name,
@@ -411,7 +451,7 @@ async def process_embed_request(
     else:
         prefix = None

-    embeddings = embed_text(
+    embeddings = await embed_text(
         texts=embed_request.texts,
         model_name=embed_request.model_name,
         deployment_name=embed_request.deployment_name,
@@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
     try:
         if rerank_request.provider_type is None:
-            sim_scores = local_rerank(
+            sim_scores = await local_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 model_name=rerank_request.model_name,
@@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
             if rerank_request.api_url is None:
                 raise ValueError("API URL is required for LiteLLM reranking.")

-            sim_scores = litellm_rerank(
+            sim_scores = await litellm_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 api_url=rerank_request.api_url,
@@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
         elif rerank_request.provider_type == RerankerProvider.COHERE:
             if rerank_request.api_key is None:
                 raise RuntimeError("Cohere Rerank Requires an API Key")
-            sim_scores = cohere_rerank(
+            sim_scores = await cohere_rerank(
                 query=rerank_request.query,
                 docs=rerank_request.documents,
                 model_name=rerank_request.model_name,
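Reviewer note: for context, the new cloud path is meant to be used roughly as below. This is a minimal usage sketch, not part of the diff; it assumes only names visible in this change (CloudEmbedding, EmbeddingProvider, EmbedTextType) and a placeholder API key, and the call would hit the real provider.

import asyncio

from model_server.encoders import CloudEmbedding
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType


async def main() -> None:
    # "async with" guarantees the underlying httpx.AsyncClient is closed
    # even if a provider call raises; aclose() is the explicit alternative.
    async with CloudEmbedding(
        api_key="fake-key",  # placeholder, not a real credential
        provider=EmbeddingProvider.OPENAI,
    ) as cloud_model:
        embeddings = await cloud_model.embed(
            texts=["hello", "world"],
            text_type=EmbedTextType.QUERY,
            model_name="text-embedding-ada-002",
        )
    print(len(embeddings))


asyncio.run(main())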

View File

@@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")

 @router.get("/health")
-def healthcheck() -> Response:
+async def healthcheck() -> Response:
     return Response(status_code=200)


 @router.get("/gpu-status")
-def gpu_status() -> dict[str, bool | str]:
+async def gpu_status() -> dict[str, bool | str]:
     if torch.cuda.is_available():
         return {"gpu_available": True, "type": "cuda"}
     elif torch.backends.mps.is_available():
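Both handlers stay trivial; moving them to async def just keeps them on the event loop alongside the embedding endpoints. A quick smoke check against a running model server (the base URL and port are assumptions, adjust to your deployment):

import asyncio

import httpx


async def main() -> None:
    # Assumes a model server listening locally on port 9000.
    async with httpx.AsyncClient(base_url="http://localhost:9000") as client:
        health = await client.get("/api/health")
        gpu = await client.get("/api/gpu-status")
        print(health.status_code, gpu.json())


asyncio.run(main())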

View File

@@ -1,3 +1,4 @@
+import asyncio
 import time
 from collections.abc import Callable
 from collections.abc import Generator
@@ -21,21 +22,39 @@ def simple_log_function_time(
     include_args: bool = False,
 ) -> Callable[[F], F]:
     def decorator(func: F) -> F:
-        @wraps(func)
-        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
-            start_time = time.time()
-            result = func(*args, **kwargs)
-            elapsed_time_str = str(time.time() - start_time)
-            log_name = func_name or func.__name__
-            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
-            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
-            if debug_only:
-                logger.debug(final_log)
-            else:
-                logger.notice(final_log)
-
-            return result
-
-        return cast(F, wrapped_func)
+        if asyncio.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
+                start_time = time.time()
+                result = await func(*args, **kwargs)
+                elapsed_time_str = str(time.time() - start_time)
+                log_name = func_name or func.__name__
+                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
+                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+                if debug_only:
+                    logger.debug(final_log)
+                else:
+                    logger.notice(final_log)
+                return result
+
+            return cast(F, wrapped_async_func)
+        else:

+            @wraps(func)
+            def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
+                start_time = time.time()
+                result = func(*args, **kwargs)
+                elapsed_time_str = str(time.time() - start_time)
+                log_name = func_name or func.__name__
+                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
+                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+                if debug_only:
+                    logger.debug(final_log)
+                else:
+                    logger.notice(final_log)
+                return result
+
+            return cast(F, wrapped_sync_func)

     return decorator
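The decorator now picks the right wrapper at decoration time, so the same timing logic covers both call styles. A minimal sketch of the two paths, assuming the decorator lives at model_server.utils (the module path is not shown in this diff):

import asyncio

from model_server.utils import simple_log_function_time


@simple_log_function_time()
def sync_work() -> int:
    # Not a coroutine function, so wrapped_sync_func is returned.
    return 1 + 1


@simple_log_function_time()
async def async_work() -> int:
    # asyncio.iscoroutinefunction(async_work) is True, so wrapped_async_func
    # is returned and the call is awaited before the elapsed time is logged.
    await asyncio.sleep(0)
    return 2 + 2


sync_work()
asyncio.run(async_work())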

View File

@@ -1,30 +1,34 @@
 black==23.3.0
+boto3-stubs[s3]==1.34.133
 celery-types==0.19.0
+cohere==5.6.1
+google-cloud-aiplatform==1.58.0
+lxml==5.3.0
+lxml_html_clean==0.2.2
 mypy-extensions==1.0.0
 mypy==1.8.0
+pandas-stubs==2.2.3.241009
+pandas==2.2.3
 pre-commit==3.2.2
+pytest-asyncio==0.22.0
 pytest==7.4.4
 reorder-python-imports==3.9.0
 ruff==0.0.286
-types-PyYAML==6.0.12.11
+sentence-transformers==2.6.1
+trafilatura==1.12.2
 types-beautifulsoup4==4.12.0.3
 types-html5lib==1.1.11.13
 types-oauthlib==3.2.0.9
-types-setuptools==68.0.0.3
-types-Pillow==10.2.0.20240822
 types-passlib==1.7.7.20240106
+types-Pillow==10.2.0.20240822
 types-psutil==5.9.5.17
 types-psycopg2==2.9.21.10
 types-python-dateutil==2.8.19.13
 types-pytz==2023.3.1.1
+types-PyYAML==6.0.12.11
 types-regex==2023.3.23.1
 types-requests==2.28.11.17
 types-retry==0.9.9.3
+types-setuptools==68.0.0.3
 types-urllib3==1.26.25.11
-trafilatura==1.12.2
-lxml==5.3.0
-lxml_html_clean==0.2.2
-boto3-stubs[s3]==1.34.133
-pandas==2.2.3
-pandas-stubs==2.2.3.241009
-cohere==5.6.1
+voyageai==0.2.3
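pytest-asyncio is the notable new dev dependency: it lets the new test file below run coroutine tests directly. The minimal shape it enables is a marked async test, as in this sketch:

import asyncio

import pytest


@pytest.mark.asyncio
async def test_event_loop_yields() -> None:
    # pytest-asyncio runs this coroutine on its own event loop.
    await asyncio.sleep(0)
    assert True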

View File

@@ -0,0 +1,198 @@
+import asyncio
+import time
+from collections.abc import AsyncGenerator
+from typing import Any
+from typing import List
+from unittest.mock import AsyncMock
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+from httpx import AsyncClient
+from litellm.exceptions import RateLimitError
+
+from model_server.encoders import CloudEmbedding
+from model_server.encoders import embed_text
+from model_server.encoders import local_rerank
+from model_server.encoders import process_embed_request
+from shared_configs.enums import EmbeddingProvider
+from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import EmbedRequest
+
+
+@pytest.fixture
+async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
+    with patch("httpx.AsyncClient") as mock:
+        client = AsyncMock(spec=AsyncClient)
+        mock.return_value = client
+        client.post = AsyncMock()
+        async with client as c:
+            yield c
+
+
+@pytest.fixture
+def sample_embeddings() -> List[List[float]]:
+    return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+
+
+@pytest.mark.asyncio
+async def test_cloud_embedding_context_manager() -> None:
+    async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
+        assert not embedding._closed
+    assert embedding._closed
+
+
+@pytest.mark.asyncio
+async def test_cloud_embedding_explicit_close() -> None:
+    embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
+    assert not embedding._closed
+    await embedding.aclose()
+    assert embedding._closed
+
+
+@pytest.mark.asyncio
+async def test_openai_embedding(
+    mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
+) -> None:
+    with patch("openai.AsyncOpenAI") as mock_openai:
+        mock_client = AsyncMock()
+        mock_openai.return_value = mock_client
+
+        mock_response = MagicMock()
+        mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
+        mock_client.embeddings.create = AsyncMock(return_value=mock_response)
+
+        embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
+        result = await embedding._embed_openai(
+            ["test1", "test2"], "text-embedding-ada-002"
+        )
+
+        assert result == sample_embeddings
+        mock_client.embeddings.create.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_embed_text_cloud_provider() -> None:
+    with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
+        mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
+        mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
+
+        result = await embed_text(
+            texts=["test1", "test2"],
+            text_type=EmbedTextType.QUERY,
+            model_name="fake-model",
+            deployment_name=None,
+            max_context_length=512,
+            normalize_embeddings=True,
+            api_key="fake-key",
+            provider_type=EmbeddingProvider.OPENAI,
+            prefix=None,
+            api_url=None,
+            api_version=None,
+        )
+
+        assert result == [[0.1, 0.2], [0.3, 0.4]]
+        mock_embed.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_embed_text_local_model() -> None:
+    with patch("model_server.encoders.get_embedding_model") as mock_get_model:
+        mock_model = MagicMock()
+        mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
+        mock_get_model.return_value = mock_model
+
+        result = await embed_text(
+            texts=["test1", "test2"],
+            text_type=EmbedTextType.QUERY,
+            model_name="fake-local-model",
+            deployment_name=None,
+            max_context_length=512,
+            normalize_embeddings=True,
+            api_key=None,
+            provider_type=None,
+            prefix=None,
+            api_url=None,
+            api_version=None,
+        )
+
+        assert result == [[0.1, 0.2], [0.3, 0.4]]
+        mock_model.encode.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_local_rerank() -> None:
+    with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
+        mock_model = MagicMock()
+        mock_array = MagicMock()
+        mock_array.tolist.return_value = [0.8, 0.6]
+        mock_model.predict.return_value = mock_array
+        mock_get_model.return_value = mock_model
+
+        result = await local_rerank(
+            query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
+        )
+
+        assert result == [0.8, 0.6]
+        mock_model.predict.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_rate_limit_handling() -> None:
+    with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
+        mock_embed.side_effect = RateLimitError(
+            "Rate limit exceeded", llm_provider="openai", model="fake-model"
+        )
+
+        with pytest.raises(RateLimitError):
+            await embed_text(
+                texts=["test"],
+                text_type=EmbedTextType.QUERY,
+                model_name="fake-model",
+                deployment_name=None,
+                max_context_length=512,
+                normalize_embeddings=True,
+                api_key="fake-key",
+                provider_type=EmbeddingProvider.OPENAI,
+                prefix=None,
+                api_url=None,
+                api_version=None,
+            )
+
+
+@pytest.mark.asyncio
+async def test_concurrent_embeddings() -> None:
+    def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
+        time.sleep(5)
+        return [[0.1, 0.2, 0.3]]
+
+    test_req = EmbedRequest(
+        texts=["test"],
+        model_name="'nomic-ai/nomic-embed-text-v1'",
+        deployment_name=None,
+        max_context_length=512,
+        normalize_embeddings=True,
+        api_key=None,
+        provider_type=None,
+        text_type=EmbedTextType.QUERY,
+        manual_query_prefix=None,
+        manual_passage_prefix=None,
+        api_url=None,
+        api_version=None,
+    )
+
+    with patch("model_server.encoders.get_embedding_model") as mock_get_model:
+        mock_model = MagicMock()
+        mock_model.encode = mock_encode
+        mock_get_model.return_value = mock_model
+
+        start_time = time.time()
+        tasks = [process_embed_request(test_req) for _ in range(5)]
+        await asyncio.gather(*tasks)
+        end_time = time.time()
+        # Run serially, 5 requests * 5 seconds = 25 seconds, so this asserts that the
+        # embedding calls at least yield the event loop. A developer could still introduce
+        # unnecessary blocking above the mock and the test would keep passing, as long as
+        # the extra blocking stays under (7 - 5) / 5 seconds per request.
+        assert end_time - start_time < 7
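The property this last test protects is the executor offload added in the encoders file above. A self-contained sketch of the same pattern, standard library only (the 1-second sleep stands in for a CPU-bound encode):

import asyncio
import time


def blocking_encode() -> list[float]:
    time.sleep(1)  # stands in for the CPU-bound model.encode() call
    return [0.1, 0.2, 0.3]


async def embed_once() -> list[float]:
    # Offloading to the default thread pool lets the event loop keep
    # serving other requests while this one computes.
    return await asyncio.get_event_loop().run_in_executor(None, blocking_encode)


async def main() -> None:
    start = time.time()
    await asyncio.gather(*(embed_once() for _ in range(5)))
    # Roughly 1 second total instead of 5, because the five calls overlap.
    print(f"elapsed: {time.time() - start:.2f}s")


asyncio.run(main())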