Model Server Async (#3386)

* need-verify

* fix some lib calls

* k

* tests

* k

* k

* k

* Address the comments

* fix comment
Yuhong Sun 2024-12-10 17:33:44 -08:00 committed by GitHub
parent 056b671cd4
commit 6026536110
5 changed files with 352 additions and 91 deletions

View File

@ -1,4 +1,6 @@
import asyncio
import json
from types import TracebackType
from typing import cast
from typing import Optional
@ -6,11 +8,11 @@ import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import Client as CohereClient
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from litellm import aembedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
@ -63,22 +65,31 @@ class CloudEmbedding:
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
async def _embed_openai(
self, texts: list[str], model: str | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = client.embeddings.create(input=text_batch, model=model)
response = await client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@ -93,19 +104,19 @@ class CloudEmbedding:
logger.error(error_string)
raise RuntimeError(error_string)
def _embed_cohere(
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
client = CohereAsyncClient(api_key=self.api_key)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Danswer API server, but it's approximately the same;
# empirically it's only off by a few tokens, so it's not a big deal
response = client.embed(
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
@ -114,26 +125,29 @@ class CloudEmbedding:
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
def _embed_voyage(
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.Client(
client = voyageai.AsyncClient(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = client.embed(
texts,
response = await client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
response = embedding(
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
@ -142,10 +156,9 @@ class CloudEmbedding:
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
def _embed_vertex(
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
@ -158,7 +171,7 @@ class CloudEmbedding:
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
embeddings = client.get_embeddings(
embeddings = await client.get_embeddings_async(
[
TextEmbeddingInput(
text,
@ -166,11 +179,11 @@ class CloudEmbedding:
)
for text in texts
],
auto_truncate=True, # Also this is default
auto_truncate=True, # This is the default
)
return [embedding.values for embedding in embeddings]
def _embed_litellm_proxy(
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
@ -183,22 +196,20 @@ class CloudEmbedding:
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
with httpx.Client() as client:
response = client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
timeout=API_BASED_EMBEDDING_TIMEOUT,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
async def embed(
self,
*,
texts: list[str],
@ -207,19 +218,19 @@ class CloudEmbedding:
deployment_name: str | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
return await self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@ -233,6 +244,30 @@ class CloudEmbedding:
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
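For reference, a minimal usage sketch of the new lifecycle API (the key, provider, and model name below are placeholders, not values taken from this commit):

from model_server.encoders import CloudEmbedding
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType

async def example() -> list[list[float]]:
    # Preferred: the context manager guarantees the underlying httpx.AsyncClient gets closed
    async with CloudEmbedding(api_key="fake-key", provider=EmbeddingProvider.OPENAI) as model:
        return await model.embed(
            texts=["hello world"],
            model_name="text-embedding-ada-002",
            text_type=EmbedTextType.QUERY,
        )

async def example_explicit() -> None:
    # Alternative: manage the client by hand; skipping aclose() triggers the __del__ warning above
    model = CloudEmbedding(api_key="fake-key", provider=EmbeddingProvider.OPENAI)
    try:
        ...
    finally:
        await model.aclose()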
def get_embedding_model(
model_name: str,
@ -242,9 +277,6 @@ def get_embedding_model(
global _GLOBAL_MODELS_DICT # A dictionary to store models
if _GLOBAL_MODELS_DICT is None:
_GLOBAL_MODELS_DICT = {}
if model_name not in _GLOBAL_MODELS_DICT:
logger.notice(f"Loading {model_name}")
# Some model architectures that aren't built into the Transformers or Sentence
@ -275,7 +307,7 @@ def get_local_reranking_model(
@simple_log_function_time()
def embed_text(
async def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
@ -311,18 +343,18 @@ def embed_text(
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
async with CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
@ -338,8 +370,12 @@ def embed_text(
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings_vectors = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
),
)
embeddings = [
embedding if isinstance(embedding, list) else embedding.tolist()
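The executor dispatch above is what keeps the event loop responsive while the local SentenceTransformer model encodes; a stripped-down sketch of the same pattern in isolation (function and variable names here are illustrative, not taken from this commit):

import asyncio

def cpu_bound_encode(texts: list[str]) -> list[list[float]]:
    # stand-in for local_model.encode(...), which holds the CPU for the whole call
    return [[float(len(text))] for text in texts]

async def encode_without_blocking(texts: list[str]) -> list[list[float]]:
    loop = asyncio.get_event_loop()
    # hand the blocking call to the default ThreadPoolExecutor so other coroutines keep running
    return await loop.run_in_executor(None, lambda: cpu_bound_encode(texts))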
@ -357,27 +393,31 @@ def embed_text(
@simple_log_function_time()
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
cross_encoder = get_local_reranking_model(model_name)
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
)
def cohere_rerank(
async def cohere_rerank(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereClient(api_key=api_key)
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
def litellm_rerank(
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
with httpx.Client() as client:
response = client.post(
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
@ -411,7 +451,7 @@ async def process_embed_request(
else:
prefix = None
embeddings = embed_text(
embeddings = await embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
try:
if rerank_request.provider_type is None:
sim_scores = local_rerank(
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = litellm_rerank(
sim_scores = await litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = cohere_rerank(
sim_scores = await cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,

View File

@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")
@router.get("/health")
def healthcheck() -> Response:
async def healthcheck() -> Response:
return Response(status_code=200)
@router.get("/gpu-status")
def gpu_status() -> dict[str, bool | str]:
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():

View File

@ -1,3 +1,4 @@
import asyncio
import time
from collections.abc import Callable
from collections.abc import Generator
@ -21,21 +22,39 @@ def simple_log_function_time(
include_args: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
if asyncio.iscoroutinefunction(func):
return result
@wraps(func)
async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = await func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
return result
return cast(F, wrapped_func)
return cast(F, wrapped_async_func)
else:
@wraps(func)
def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
return result
return cast(F, wrapped_sync_func)
return decorator
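With the wrapper split on asyncio.iscoroutinefunction, the same decorator now times both coroutines and plain functions; a small usage sketch (the decorated functions are made up for illustration):

import asyncio

@simple_log_function_time()
async def fetch_remote_embeddings() -> None:
    await asyncio.sleep(0.1)  # timed by wrapped_async_func

@simple_log_function_time()
def tokenize_locally() -> None:
    pass  # timed by wrapped_sync_func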

View File

@ -1,30 +1,34 @@
black==23.3.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0
lxml==5.3.0
lxml_html_clean==0.2.2
mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
types-PyYAML==6.0.12.11
sentence-transformers==2.6.1
trafilatura==1.12.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-setuptools==68.0.0.3
types-Pillow==10.2.0.20240822
types-passlib==1.7.7.20240106
types-Pillow==10.2.0.20240822
types-psutil==5.9.5.17
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-pytz==2023.3.1.1
types-PyYAML==6.0.12.11
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-retry==0.9.9.3
types-setuptools==68.0.0.3
types-urllib3==1.26.25.11
trafilatura==1.12.2
lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
cohere==5.6.1
voyageai==0.2.3

View File

@ -0,0 +1,198 @@
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Any
from typing import List
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from httpx import AsyncClient
from litellm.exceptions import RateLimitError
from model_server.encoders import CloudEmbedding
from model_server.encoders import embed_text
from model_server.encoders import local_rerank
from model_server.encoders import process_embed_request
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest
@pytest.fixture
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
with patch("httpx.AsyncClient") as mock:
client = AsyncMock(spec=AsyncClient)
mock.return_value = client
client.post = AsyncMock()
async with client as c:
yield c
@pytest.fixture
def sample_embeddings() -> List[List[float]]:
return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@pytest.mark.asyncio
async def test_cloud_embedding_context_manager() -> None:
async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
assert not embedding._closed
assert embedding._closed
@pytest.mark.asyncio
async def test_cloud_embedding_explicit_close() -> None:
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
assert not embedding._closed
await embedding.aclose()
assert embedding._closed
@pytest.mark.asyncio
async def test_openai_embedding(
mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
) -> None:
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
result = await embedding._embed_openai(
["test1", "test2"], "text-embedding-ada-002"
)
assert result == sample_embeddings
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_embed_text_cloud_provider() -> None:
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
result = await embed_text(
texts=["test1", "test2"],
text_type=EmbedTextType.QUERY,
model_name="fake-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key="fake-key",
provider_type=EmbeddingProvider.OPENAI,
prefix=None,
api_url=None,
api_version=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_embed.assert_called_once()
@pytest.mark.asyncio
async def test_embed_text_local_model() -> None:
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
mock_model = MagicMock()
mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
mock_get_model.return_value = mock_model
result = await embed_text(
texts=["test1", "test2"],
text_type=EmbedTextType.QUERY,
model_name="fake-local-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key=None,
provider_type=None,
prefix=None,
api_url=None,
api_version=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_model.encode.assert_called_once()
@pytest.mark.asyncio
async def test_local_rerank() -> None:
with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
mock_model = MagicMock()
mock_array = MagicMock()
mock_array.tolist.return_value = [0.8, 0.6]
mock_model.predict.return_value = mock_array
mock_get_model.return_value = mock_model
result = await local_rerank(
query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
)
assert result == [0.8, 0.6]
mock_model.predict.assert_called_once()
@pytest.mark.asyncio
async def test_rate_limit_handling() -> None:
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
mock_embed.side_effect = RateLimitError(
"Rate limit exceeded", llm_provider="openai", model="fake-model"
)
with pytest.raises(RateLimitError):
await embed_text(
texts=["test"],
text_type=EmbedTextType.QUERY,
model_name="fake-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key="fake-key",
provider_type=EmbeddingProvider.OPENAI,
prefix=None,
api_url=None,
api_version=None,
)
@pytest.mark.asyncio
async def test_concurrent_embeddings() -> None:
def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
time.sleep(5)
return [[0.1, 0.2, 0.3]]
test_req = EmbedRequest(
texts=["test"],
model_name="'nomic-ai/nomic-embed-text-v1'",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key=None,
provider_type=None,
text_type=EmbedTextType.QUERY,
manual_query_prefix=None,
manual_passage_prefix=None,
api_url=None,
api_version=None,
)
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
mock_model = MagicMock()
mock_model.encode = mock_encode
mock_get_model.return_value = mock_model
start_time = time.time()
tasks = [process_embed_request(test_req) for _ in range(5)]
await asyncio.gather(*tasks)
end_time = time.time()
# Serially this would be 5 * 5 = 25 seconds; the 7 second bound checks that the embedding calls
# at least yield the event loop so the five requests overlap (roughly 5 seconds total). A developer
# could still introduce blocking work above the mock and the test would pass, but only if it adds
# less than (7 - 5) / 5 = 0.4 seconds per request.
assert end_time - start_time < 7
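For contrast, a sketch of the failure mode this timing bound guards against (illustrative only, not code from this commit): if the handler ran the blocking encode directly on the event loop instead of dispatching it to an executor, the five gathered requests could not overlap and the total time would approach 25 seconds.

import time
from typing import List

async def blocking_variant(texts: List[str]) -> List[List[float]]:
    # the 5 second sleep runs on the event loop itself, so gathered callers execute one after another
    time.sleep(5)
    return [[0.1, 0.2, 0.3] for _ in texts]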