From 60265361106660578377d34dcf21eabe274ba906 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 10 Dec 2024 17:33:44 -0800 Subject: [PATCH] Model Server Async (#3386) * need-verify * fix some lib calls * k * tests * k * k * k * Address the comments * fix comment --- backend/model_server/encoders.py | 170 +++++++++------ backend/model_server/management_endpoints.py | 4 +- backend/model_server/utils.py | 47 +++-- backend/requirements/dev.txt | 24 ++- .../tests/unit/model_server/test_embedding.py | 198 ++++++++++++++++++ 5 files changed, 352 insertions(+), 91 deletions(-) create mode 100644 backend/tests/unit/model_server/test_embedding.py diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index c72be9e4a..ef04c0a7f 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,4 +1,6 @@ +import asyncio import json +from types import TracebackType from typing import cast from typing import Optional @@ -6,11 +8,11 @@ import httpx import openai import vertexai # type: ignore import voyageai # type: ignore -from cohere import Client as CohereClient +from cohere import AsyncClient as CohereAsyncClient from fastapi import APIRouter from fastapi import HTTPException from google.oauth2 import service_account # type: ignore -from litellm import embedding +from litellm import aembedding from litellm.exceptions import RateLimitError from retry import retry from sentence_transformers import CrossEncoder # type: ignore @@ -63,22 +65,31 @@ class CloudEmbedding: provider: EmbeddingProvider, api_url: str | None = None, api_version: str | None = None, + timeout: int = API_BASED_EMBEDDING_TIMEOUT, ) -> None: self.provider = provider self.api_key = api_key self.api_url = api_url self.api_version = api_version + self.timeout = timeout + self.http_client = httpx.AsyncClient(timeout=timeout) + self._closed = False - def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]: + async def _embed_openai( + self, texts: list[str], model: str | None + ) -> list[Embedding]: if not model: model = DEFAULT_OPENAI_MODEL - client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT) + # Use the OpenAI specific timeout for this one + client = openai.AsyncOpenAI( + api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT + ) final_embeddings: list[Embedding] = [] try: for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN): - response = client.embeddings.create(input=text_batch, model=model) + response = await client.embeddings.create(input=text_batch, model=model) final_embeddings.extend( [embedding.embedding for embedding in response.data] ) @@ -93,19 +104,19 @@ class CloudEmbedding: logger.error(error_string) raise RuntimeError(error_string) - def _embed_cohere( + async def _embed_cohere( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_COHERE_MODEL - client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT) + client = CohereAsyncClient(api_key=self.api_key) final_embeddings: list[Embedding] = [] for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN): # Does not use the same tokenizer as the Danswer API server but it's approximately the same # empirically it's only off by a very few tokens so it's not a big deal - response = client.embed( + response = await client.embed( texts=text_batch, model=model, input_type=embedding_type, @@ -114,26 +125,29 @@ class CloudEmbedding: final_embeddings.extend(cast(list[Embedding], 
response.embeddings)) return final_embeddings - def _embed_voyage( + async def _embed_voyage( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_VOYAGE_MODEL - client = voyageai.Client( + client = voyageai.AsyncClient( api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT ) - response = client.embed( - texts, + response = await client.embed( + texts=texts, model=model, input_type=embedding_type, truncation=True, ) + return response.embeddings - def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]: - response = embedding( + async def _embed_azure( + self, texts: list[str], model: str | None + ) -> list[Embedding]: + response = await aembedding( model=model, input=texts, timeout=API_BASED_EMBEDDING_TIMEOUT, @@ -142,10 +156,9 @@ class CloudEmbedding: api_version=self.api_version, ) embeddings = [embedding["embedding"] for embedding in response.data] - return embeddings - def _embed_vertex( + async def _embed_vertex( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: @@ -158,7 +171,7 @@ class CloudEmbedding: vertexai.init(project=project_id, credentials=credentials) client = TextEmbeddingModel.from_pretrained(model) - embeddings = client.get_embeddings( + embeddings = await client.get_embeddings_async( [ TextEmbeddingInput( text, @@ -166,11 +179,11 @@ class CloudEmbedding: ) for text in texts ], - auto_truncate=True, # Also this is default + auto_truncate=True, # This is the default ) return [embedding.values for embedding in embeddings] - def _embed_litellm_proxy( + async def _embed_litellm_proxy( self, texts: list[str], model_name: str | None ) -> list[Embedding]: if not model_name: @@ -183,22 +196,20 @@ class CloudEmbedding: {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"} ) - with httpx.Client() as client: - response = client.post( - self.api_url, - json={ - "model": model_name, - "input": texts, - }, - headers=headers, - timeout=API_BASED_EMBEDDING_TIMEOUT, - ) - response.raise_for_status() - result = response.json() - return [embedding["embedding"] for embedding in result["data"]] + response = await self.http_client.post( + self.api_url, + json={ + "model": model_name, + "input": texts, + }, + headers=headers, + ) + response.raise_for_status() + result = response.json() + return [embedding["embedding"] for embedding in result["data"]] @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY) - def embed( + async def embed( self, *, texts: list[str], @@ -207,19 +218,19 @@ class CloudEmbedding: deployment_name: str | None = None, ) -> list[Embedding]: if self.provider == EmbeddingProvider.OPENAI: - return self._embed_openai(texts, model_name) + return await self._embed_openai(texts, model_name) elif self.provider == EmbeddingProvider.AZURE: - return self._embed_azure(texts, f"azure/{deployment_name}") + return await self._embed_azure(texts, f"azure/{deployment_name}") elif self.provider == EmbeddingProvider.LITELLM: - return self._embed_litellm_proxy(texts, model_name) + return await self._embed_litellm_proxy(texts, model_name) embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) if self.provider == EmbeddingProvider.COHERE: - return self._embed_cohere(texts, model_name, embedding_type) + return await self._embed_cohere(texts, model_name, embedding_type) elif self.provider == EmbeddingProvider.VOYAGE: - return self._embed_voyage(texts, model_name, embedding_type) + return await self._embed_voyage(texts, model_name, 
embedding_type) elif self.provider == EmbeddingProvider.GOOGLE: - return self._embed_vertex(texts, model_name, embedding_type) + return await self._embed_vertex(texts, model_name, embedding_type) else: raise ValueError(f"Unsupported provider: {self.provider}") @@ -233,6 +244,30 @@ class CloudEmbedding: logger.debug(f"Creating Embedding instance for provider: {provider}") return CloudEmbedding(api_key, provider, api_url, api_version) + async def aclose(self) -> None: + """Explicitly close the client.""" + if not self._closed: + await self.http_client.aclose() + self._closed = True + + async def __aenter__(self) -> "CloudEmbedding": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.aclose() + + def __del__(self) -> None: + """Finalizer to warn about unclosed clients.""" + if not self._closed: + logger.warning( + "CloudEmbedding was not properly closed. Use 'async with' or call aclose()" + ) + def get_embedding_model( model_name: str, @@ -242,9 +277,6 @@ def get_embedding_model( global _GLOBAL_MODELS_DICT # A dictionary to store models - if _GLOBAL_MODELS_DICT is None: - _GLOBAL_MODELS_DICT = {} - if model_name not in _GLOBAL_MODELS_DICT: logger.notice(f"Loading {model_name}") # Some model architectures that aren't built into the Transformers or Sentence @@ -275,7 +307,7 @@ def get_local_reranking_model( @simple_log_function_time() -def embed_text( +async def embed_text( texts: list[str], text_type: EmbedTextType, model_name: str | None, @@ -311,18 +343,18 @@ def embed_text( "Cloud models take an explicit text type instead." ) - cloud_model = CloudEmbedding( + async with CloudEmbedding( api_key=api_key, provider=provider_type, api_url=api_url, api_version=api_version, - ) - embeddings = cloud_model.embed( - texts=texts, - model_name=model_name, - deployment_name=deployment_name, - text_type=text_type, - ) + ) as cloud_model: + embeddings = await cloud_model.embed( + texts=texts, + model_name=model_name, + deployment_name=deployment_name, + text_type=text_type, + ) if any(embedding is None for embedding in embeddings): error_message = "Embeddings contain None values\n" @@ -338,8 +370,12 @@ def embed_text( local_model = get_embedding_model( model_name=model_name, max_context_length=max_context_length ) - embeddings_vectors = local_model.encode( - prefixed_texts, normalize_embeddings=normalize_embeddings + # Run CPU-bound embedding in a thread pool + embeddings_vectors = await asyncio.get_event_loop().run_in_executor( + None, + lambda: local_model.encode( + prefixed_texts, normalize_embeddings=normalize_embeddings + ), ) embeddings = [ embedding if isinstance(embedding, list) else embedding.tolist() @@ -357,27 +393,31 @@ def embed_text( @simple_log_function_time() -def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]: +async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]: cross_encoder = get_local_reranking_model(model_name) - return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore + # Run CPU-bound reranking in a thread pool + return await asyncio.get_event_loop().run_in_executor( + None, + lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore + ) -def cohere_rerank( +async def cohere_rerank( query: str, docs: list[str], model_name: str, api_key: str ) -> list[float]: - cohere_client = CohereClient(api_key=api_key) - response = 
cohere_client.rerank(query=query, documents=docs, model=model_name) + cohere_client = CohereAsyncClient(api_key=api_key) + response = await cohere_client.rerank(query=query, documents=docs, model=model_name) results = response.results sorted_results = sorted(results, key=lambda item: item.index) return [result.relevance_score for result in sorted_results] -def litellm_rerank( +async def litellm_rerank( query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None ) -> list[float]: headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"} - with httpx.Client() as client: - response = client.post( + async with httpx.AsyncClient() as client: + response = await client.post( api_url, json={ "model": model_name, @@ -411,7 +451,7 @@ async def process_embed_request( else: prefix = None - embeddings = embed_text( + embeddings = await embed_text( texts=embed_request.texts, model_name=embed_request.model_name, deployment_name=embed_request.deployment_name, @@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons try: if rerank_request.provider_type is None: - sim_scores = local_rerank( + sim_scores = await local_rerank( query=rerank_request.query, docs=rerank_request.documents, model_name=rerank_request.model_name, @@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons if rerank_request.api_url is None: raise ValueError("API URL is required for LiteLLM reranking.") - sim_scores = litellm_rerank( + sim_scores = await litellm_rerank( query=rerank_request.query, docs=rerank_request.documents, api_url=rerank_request.api_url, @@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons elif rerank_request.provider_type == RerankerProvider.COHERE: if rerank_request.api_key is None: raise RuntimeError("Cohere Rerank Requires an API Key") - sim_scores = cohere_rerank( + sim_scores = await cohere_rerank( query=rerank_request.query, docs=rerank_request.documents, model_name=rerank_request.model_name, diff --git a/backend/model_server/management_endpoints.py b/backend/model_server/management_endpoints.py index 56640a2fa..4c6387e07 100644 --- a/backend/model_server/management_endpoints.py +++ b/backend/model_server/management_endpoints.py @@ -6,12 +6,12 @@ router = APIRouter(prefix="/api") @router.get("/health") -def healthcheck() -> Response: +async def healthcheck() -> Response: return Response(status_code=200) @router.get("/gpu-status") -def gpu_status() -> dict[str, bool | str]: +async def gpu_status() -> dict[str, bool | str]: if torch.cuda.is_available(): return {"gpu_available": True, "type": "cuda"} elif torch.backends.mps.is_available(): diff --git a/backend/model_server/utils.py b/backend/model_server/utils.py index 0c2d6bac5..86192b031 100644 --- a/backend/model_server/utils.py +++ b/backend/model_server/utils.py @@ -1,3 +1,4 @@ +import asyncio import time from collections.abc import Callable from collections.abc import Generator @@ -21,21 +22,39 @@ def simple_log_function_time( include_args: bool = False, ) -> Callable[[F], F]: def decorator(func: F) -> F: - @wraps(func) - def wrapped_func(*args: Any, **kwargs: Any) -> Any: - start_time = time.time() - result = func(*args, **kwargs) - elapsed_time_str = str(time.time() - start_time) - log_name = func_name or func.__name__ - args_str = f" args={args} kwargs={kwargs}" if include_args else "" - final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" - if debug_only: - 
logger.debug(final_log) - else: - logger.notice(final_log) + if asyncio.iscoroutinefunction(func): - return result + @wraps(func) + async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = await func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.notice(final_log) + return result - return cast(F, wrapped_func) + return cast(F, wrapped_async_func) + else: + + @wraps(func) + def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any: + start_time = time.time() + result = func(*args, **kwargs) + elapsed_time_str = str(time.time() - start_time) + log_name = func_name or func.__name__ + args_str = f" args={args} kwargs={kwargs}" if include_args else "" + final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds" + if debug_only: + logger.debug(final_log) + else: + logger.notice(final_log) + return result + + return cast(F, wrapped_sync_func) return decorator diff --git a/backend/requirements/dev.txt b/backend/requirements/dev.txt index 27304dbef..a89b8db67 100644 --- a/backend/requirements/dev.txt +++ b/backend/requirements/dev.txt @@ -1,30 +1,34 @@ black==23.3.0 +boto3-stubs[s3]==1.34.133 celery-types==0.19.0 +cohere==5.6.1 +google-cloud-aiplatform==1.58.0 +lxml==5.3.0 +lxml_html_clean==0.2.2 mypy-extensions==1.0.0 mypy==1.8.0 +pandas-stubs==2.2.3.241009 +pandas==2.2.3 pre-commit==3.2.2 +pytest-asyncio==0.22.0 pytest==7.4.4 reorder-python-imports==3.9.0 ruff==0.0.286 -types-PyYAML==6.0.12.11 +sentence-transformers==2.6.1 +trafilatura==1.12.2 types-beautifulsoup4==4.12.0.3 types-html5lib==1.1.11.13 types-oauthlib==3.2.0.9 -types-setuptools==68.0.0.3 -types-Pillow==10.2.0.20240822 types-passlib==1.7.7.20240106 +types-Pillow==10.2.0.20240822 types-psutil==5.9.5.17 types-psycopg2==2.9.21.10 types-python-dateutil==2.8.19.13 types-pytz==2023.3.1.1 +types-PyYAML==6.0.12.11 types-regex==2023.3.23.1 types-requests==2.28.11.17 types-retry==0.9.9.3 +types-setuptools==68.0.0.3 types-urllib3==1.26.25.11 -trafilatura==1.12.2 -lxml==5.3.0 -lxml_html_clean==0.2.2 -boto3-stubs[s3]==1.34.133 -pandas==2.2.3 -pandas-stubs==2.2.3.241009 -cohere==5.6.1 \ No newline at end of file +voyageai==0.2.3 diff --git a/backend/tests/unit/model_server/test_embedding.py b/backend/tests/unit/model_server/test_embedding.py new file mode 100644 index 000000000..6781ab27a --- /dev/null +++ b/backend/tests/unit/model_server/test_embedding.py @@ -0,0 +1,198 @@ +import asyncio +import time +from collections.abc import AsyncGenerator +from typing import Any +from typing import List +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from httpx import AsyncClient +from litellm.exceptions import RateLimitError + +from model_server.encoders import CloudEmbedding +from model_server.encoders import embed_text +from model_server.encoders import local_rerank +from model_server.encoders import process_embed_request +from shared_configs.enums import EmbeddingProvider +from shared_configs.enums import EmbedTextType +from shared_configs.model_server_models import EmbedRequest + + +@pytest.fixture +async def mock_http_client() -> AsyncGenerator[AsyncMock, None]: + with patch("httpx.AsyncClient") as mock: + client = AsyncMock(spec=AsyncClient) + mock.return_value = 
client + client.post = AsyncMock() + async with client as c: + yield c + + +@pytest.fixture +def sample_embeddings() -> List[List[float]]: + return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + + +@pytest.mark.asyncio +async def test_cloud_embedding_context_manager() -> None: + async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding: + assert not embedding._closed + assert embedding._closed + + +@pytest.mark.asyncio +async def test_cloud_embedding_explicit_close() -> None: + embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) + assert not embedding._closed + await embedding.aclose() + assert embedding._closed + + +@pytest.mark.asyncio +async def test_openai_embedding( + mock_http_client: AsyncMock, sample_embeddings: List[List[float]] +) -> None: + with patch("openai.AsyncOpenAI") as mock_openai: + mock_client = AsyncMock() + mock_openai.return_value = mock_client + + mock_response = MagicMock() + mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings] + mock_client.embeddings.create = AsyncMock(return_value=mock_response) + + embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) + result = await embedding._embed_openai( + ["test1", "test2"], "text-embedding-ada-002" + ) + + assert result == sample_embeddings + mock_client.embeddings.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_cloud_provider() -> None: + with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed: + mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]]) + + result = await embed_text( + texts=["test1", "test2"], + text_type=EmbedTextType.QUERY, + model_name="fake-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key="fake-key", + provider_type=EmbeddingProvider.OPENAI, + prefix=None, + api_url=None, + api_version=None, + ) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + mock_embed.assert_called_once() + + +@pytest.mark.asyncio +async def test_embed_text_local_model() -> None: + with patch("model_server.encoders.get_embedding_model") as mock_get_model: + mock_model = MagicMock() + mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]] + mock_get_model.return_value = mock_model + + result = await embed_text( + texts=["test1", "test2"], + text_type=EmbedTextType.QUERY, + model_name="fake-local-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key=None, + provider_type=None, + prefix=None, + api_url=None, + api_version=None, + ) + + assert result == [[0.1, 0.2], [0.3, 0.4]] + mock_model.encode.assert_called_once() + + +@pytest.mark.asyncio +async def test_local_rerank() -> None: + with patch("model_server.encoders.get_local_reranking_model") as mock_get_model: + mock_model = MagicMock() + mock_array = MagicMock() + mock_array.tolist.return_value = [0.8, 0.6] + mock_model.predict.return_value = mock_array + mock_get_model.return_value = mock_model + + result = await local_rerank( + query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model" + ) + + assert result == [0.8, 0.6] + mock_model.predict.assert_called_once() + + +@pytest.mark.asyncio +async def test_rate_limit_handling() -> None: + with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed: + mock_embed.side_effect = RateLimitError( + "Rate limit exceeded", llm_provider="openai", model="fake-model" + ) + + with pytest.raises(RateLimitError): + await embed_text( + texts=["test"], + 
text_type=EmbedTextType.QUERY, + model_name="fake-model", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key="fake-key", + provider_type=EmbeddingProvider.OPENAI, + prefix=None, + api_url=None, + api_version=None, + ) + + +@pytest.mark.asyncio +async def test_concurrent_embeddings() -> None: + def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]: + time.sleep(5) + return [[0.1, 0.2, 0.3]] + + test_req = EmbedRequest( + texts=["test"], + model_name="'nomic-ai/nomic-embed-text-v1'", + deployment_name=None, + max_context_length=512, + normalize_embeddings=True, + api_key=None, + provider_type=None, + text_type=EmbedTextType.QUERY, + manual_query_prefix=None, + manual_passage_prefix=None, + api_url=None, + api_version=None, + ) + + with patch("model_server.encoders.get_embedding_model") as mock_get_model: + mock_model = MagicMock() + mock_model.encode = mock_encode + mock_get_model.return_value = mock_model + start_time = time.time() + + tasks = [process_embed_request(test_req) for _ in range(5)] + await asyncio.gather(*tasks) + + end_time = time.time() + + # 5 * 5 seconds = 25 seconds, this test ensures that the embeddings are at least yielding the thread + # However, the developer may still introduce unnecessary blocking above the mock and this test will + # still pass as long as it's less than (7 - 5) / 5 seconds + assert end_time - start_time < 7
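
For reference, a minimal usage sketch of the async CloudEmbedding API introduced in this patch. It is not part of the diff above; the API key and model name are placeholders, and a real provider key would be needed to actually run it. It only illustrates the intended call pattern: construct the client, embed inside "async with" so the underlying httpx.AsyncClient is released via aclose().

    import asyncio

    from model_server.encoders import CloudEmbedding
    from shared_configs.enums import EmbeddingProvider
    from shared_configs.enums import EmbedTextType


    async def main() -> None:
        # "fake-key" and the model name are placeholders; substitute real values.
        async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as client:
            embeddings = await client.embed(
                texts=["hello world"],
                model_name="text-embedding-ada-002",
                text_type=EmbedTextType.QUERY,
            )
        # One embedding vector per input text; the HTTP client is closed on exit.
        print(len(embeddings))


    asyncio.run(main())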