mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-26 12:23:13 +02:00
* trying out a fix * add ability to manually run model tests * add log dump * check status code, not text? * just the model server * add port mapping to host * pass through more api keys * add azure tests * fix litellm env vars * fix env vars in github workflow * temp disable litellm test --------- Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
148 lines
5.1 KiB
Python
148 lines
5.1 KiB
Python
import os
|
|
|
|
import pytest
|
|
|
|
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
|
from shared_configs.enums import EmbedTextType
|
|
from shared_configs.model_server_models import EmbeddingProvider
|
|
|
|
# Short, well-formed inputs, including punctuation and a non-ASCII emoji.
VALID_SAMPLE = [
    "hi",
    "hello my name is bob",
    "woah there!!!. 😃",
]

# A single input made of 999 repetitions of "hi " (used by the disabled
# Azure rate-limit test at the bottom of this file).
VALID_LONG_SAMPLE = ["hi " * 999]

# More inputs than one provider request accepts: OpenAI's limit is 2048.
# Cohere's limit is supposed to be 96, but in practice that doesn't seem
# to be enforced.
TOO_LONG_SAMPLE = ["a" for _ in range(2500)]
|
|
|
|
|
|
def _run_embeddings(
    texts: list[str], embedding_model: EmbeddingModel, expected_dim: int
) -> None:
    """Embed ``texts`` as both queries and passages and sanity-check the output.

    Args:
        texts: Inputs to embed; the same list is encoded once per text type.
        embedding_model: Client for the model server under test.
        expected_dim: Dimensionality every returned vector must have.

    Raises:
        AssertionError: If the number of returned embeddings differs from the
            number of inputs, or if any embedding has the wrong dimensionality.
    """
    for text_type in [EmbedTextType.QUERY, EmbedTextType.PASSAGE]:
        embeddings = embedding_model.encode(texts, text_type)
        assert len(embeddings) == len(texts)
        # Check every vector rather than only the first one. Large inputs
        # (e.g. TOO_LONG_SAMPLE) may be split into several requests, and a bad
        # later chunk would otherwise go unnoticed. This also avoids an
        # IndexError when ``texts`` is empty.
        for index, embedding in enumerate(embeddings):
            assert len(embedding) == expected_dim, (
                f"{text_type} embedding {index} has dimension {len(embedding)}, "
                f"expected {expected_dim}"
            )
|
|
|
|
|
|
@pytest.fixture
def openai_embedding_model() -> EmbeddingModel:
    """Client for OpenAI ``text-embedding-3-small`` via the model server.

    The model server is expected at localhost:9000. The API key is read from
    the ``OPENAI_API_KEY`` environment variable, and no query/passage prefixes
    are applied.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    return EmbeddingModel(
        provider_type=EmbeddingProvider.OPENAI,
        model_name="text-embedding-3-small",
        api_key=api_key,
        api_url=None,
        server_host="localhost",
        server_port=9000,
        query_prefix=None,
        passage_prefix=None,
        normalize=True,
    )
|
|
|
|
|
|
def test_openai_embedding(openai_embedding_model: EmbeddingModel) -> None:
    """OpenAI embeddings are 1536-dimensional for both normal and oversized batches."""
    for sample in (VALID_SAMPLE, TOO_LONG_SAMPLE):
        _run_embeddings(sample, openai_embedding_model, 1536)
|
|
|
|
|
|
@pytest.fixture
def cohere_embedding_model() -> EmbeddingModel:
    """Client for Cohere ``embed-english-light-v3.0`` via the model server.

    The model server is expected at localhost:9000. The API key is read from
    the ``COHERE_API_KEY`` environment variable, and no query/passage prefixes
    are applied.
    """
    api_key = os.getenv("COHERE_API_KEY")
    return EmbeddingModel(
        provider_type=EmbeddingProvider.COHERE,
        model_name="embed-english-light-v3.0",
        api_key=api_key,
        api_url=None,
        server_host="localhost",
        server_port=9000,
        query_prefix=None,
        passage_prefix=None,
        normalize=True,
    )
|
|
|
|
|
|
def test_cohere_embedding(cohere_embedding_model: EmbeddingModel) -> None:
    """Cohere light embeddings are 384-dimensional for both normal and oversized batches."""
    for sample in (VALID_SAMPLE, TOO_LONG_SAMPLE):
        _run_embeddings(sample, cohere_embedding_model, 384)
|
|
|
|
|
|
@pytest.fixture
def litellm_embedding_model() -> EmbeddingModel:
    """Client for ``text-embedding-3-small`` served through a LiteLLM proxy.

    The model server is expected at localhost:9000. The proxy's key and URL
    are read from the ``LITELLM_API_KEY`` and ``LITELLM_API_URL`` environment
    variables.
    """
    api_key = os.getenv("LITELLM_API_KEY")
    api_url = os.getenv("LITELLM_API_URL")
    return EmbeddingModel(
        provider_type=EmbeddingProvider.LITELLM,
        model_name="text-embedding-3-small",
        api_key=api_key,
        api_url=api_url,
        server_host="localhost",
        server_port=9000,
        query_prefix=None,
        passage_prefix=None,
        normalize=True,
    )
|
|
|
|
|
|
@pytest.mark.skip(reason="re-enable when we can get the correct litellm key and url")
def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
    """LiteLLM-proxied embeddings are 1536-dimensional (currently skipped)."""
    for sample in (VALID_SAMPLE, TOO_LONG_SAMPLE):
        _run_embeddings(sample, litellm_embedding_model, 1536)
|
|
|
|
|
|
@pytest.fixture
def local_nomic_embedding_model() -> EmbeddingModel:
    """Client for the locally hosted ``nomic-ai/nomic-embed-text-v1`` model.

    No cloud provider, key, or URL is configured. Queries and passages get the
    ``search_query: `` and ``search_document: `` prefixes respectively.
    """
    return EmbeddingModel(
        provider_type=None,
        model_name="nomic-ai/nomic-embed-text-v1",
        api_key=None,
        api_url=None,
        server_host="localhost",
        server_port=9000,
        query_prefix="search_query: ",
        passage_prefix="search_document: ",
        normalize=True,
    )
|
|
|
|
|
|
def test_local_nomic_embedding(local_nomic_embedding_model: EmbeddingModel) -> None:
    """Local nomic embeddings are 768-dimensional for both normal and oversized batches."""
    for sample in (VALID_SAMPLE, TOO_LONG_SAMPLE):
        _run_embeddings(sample, local_nomic_embedding_model, 768)
|
|
|
|
|
|
@pytest.fixture
def azure_embedding_model() -> EmbeddingModel:
    """Client for an Azure-hosted ``text-embedding-3-large`` deployment.

    The model server is expected at localhost:9000. The deployment's key and
    URL are read from the ``AZURE_API_KEY`` and ``AZURE_API_URL`` environment
    variables.
    """
    api_key = os.getenv("AZURE_API_KEY")
    api_url = os.getenv("AZURE_API_URL")
    return EmbeddingModel(
        provider_type=EmbeddingProvider.AZURE,
        model_name="text-embedding-3-large",
        api_key=api_key,
        api_url=api_url,
        server_host="localhost",
        server_port=9000,
        query_prefix=None,
        passage_prefix=None,
        normalize=True,
    )
|
|
|
|
|
|
def test_azure_embedding(azure_embedding_model: EmbeddingModel) -> None:
    """Azure embeddings are 1536-dimensional for both normal and oversized batches.

    NOTE(review): OpenAI documents ``text-embedding-3-large`` as 3072-dimensional
    by default. The 1536 here presumably reflects how the Azure deployment behind
    ``AZURE_API_URL`` is configured; confirm before changing either value.
    """
    for sample in (VALID_SAMPLE, TOO_LONG_SAMPLE):
        _run_embeddings(sample, azure_embedding_model, 1536)
|
|
|
|
|
|
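# NOTE (chris): this test is disabled because it doesn't work, and the cause is unknown.
# Re-enabling it also requires importing `time` and `ModelServerRateLimitError`.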
|
|
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
|
|
# """NOTE: this test relies on a very low rate limit for the Azure API +
|
|
# this test only being run once in a 1 minute window"""
|
|
# # VALID_LONG_SAMPLE is 999 tokens, so the second call should run into rate
|
|
# # limits assuming the limit is 1000 tokens per minute
|
|
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
|
|
# assert len(result) == 1
|
|
# assert len(result[0]) == 1536
|
|
|
|
# # this should fail
|
|
# with pytest.raises(ModelServerRateLimitError):
|
|
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
|
|
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
|
|
# azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.QUERY)
|
|
|
|
# # this should succeed, since passage requests retry up to 10 times
|
|
# start = time.time()
|
|
# result = azure_embedding_model.encode(VALID_LONG_SAMPLE, EmbedTextType.PASSAGE)
|
|
# assert len(result) == 1
|
|
# assert len(result[0]) == 1536
|
|
# assert time.time() - start > 30 # make sure we waited, even though we hit rate limits
|