Harden embedding calls

This commit is contained in:
Weves 2024-08-04 14:47:19 -07:00 committed by Chris Weaver
parent 9d7100a287
commit 76b7792e69

View File

@ -2,6 +2,7 @@ import time
import requests
from httpx import HTTPError
from retry import retry
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import (
@ -72,6 +73,86 @@ class EmbeddingModel:
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
def _make_request() -> EmbedResponse:
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
try:
error_detail = response.json().get("detail", str(e))
except Exception:
error_detail = response.text
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
return EmbedResponse(**response.json())
# only perform retries for the non-realtime embedding of passages (e.g. for indexing)
if embed_request.text_type == EmbedTextType.PASSAGE:
return retry(tries=3, delay=5)(_make_request)()
else:
return _make_request()
def _encode_api_model(
    self, texts: list[str], text_type: EmbedTextType, batch_size: int
) -> list[Embedding]:
    """Embed `texts` via the configured cloud provider, one batch at a time.

    Raises:
        ValueError: if no provider type is configured on this model.
    """
    if not self.provider_type:
        raise ValueError("Provider type is not set for API embedding")

    text_batches = batch_list(texts, batch_size)
    embeddings: list[Embedding] = []
    for idx, text_batch in enumerate(text_batches, start=1):
        logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
        # build and send the request in one step; only `texts` varies per batch
        response = self._make_model_server_request(
            EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                max_context_length=self.max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
            )
        )
        embeddings.extend(response.embeddings)
    return embeddings
def _encode_local_model(
    self,
    texts: list[str],
    text_type: EmbedTextType,
    batch_size: int,
) -> list[Embedding]:
    """Embed `texts` with the locally hosted model, one batch at a time."""
    text_batches = batch_list(texts, batch_size)
    logger.debug(
        f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
    )

    # everything except the batch text is identical across requests
    shared_fields = dict(
        model_name=self.model_name,
        max_context_length=self.max_seq_length,
        normalize_embeddings=self.normalize,
        api_key=self.api_key,
        provider_type=self.provider_type,
        text_type=text_type,
        manual_query_prefix=self.query_prefix,
        manual_passage_prefix=self.passage_prefix,
    )

    embeddings: list[Embedding] = []
    for idx, text_batch in enumerate(text_batches, start=1):
        logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
        batch_response = self._make_model_server_request(
            EmbedRequest(texts=text_batch, **shared_fields)
        )
        embeddings.extend(batch_response.embeddings)
    return embeddings
def encode(
self,
texts: list[str],
@ -99,65 +180,14 @@ class EmbeddingModel:
]
if self.provider_type:
text_batches = batch_list(texts, api_embedding_batch_size)
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
return self._encode_api_model(
texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
return EmbedResponse(**response.json()).embeddings
# Batching for local embedding
text_batches = batch_list(texts, local_embedding_batch_size)
embeddings: list[Embedding] = []
logger.debug(
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
# if no provider, use local model
return self._encode_local_model(
texts=texts, text_type=text_type, batch_size=local_embedding_batch_size
)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Normalize embeddings is only configured via model_configs.py, be sure to use right
# value for the set loss
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
class CrossEncoderEnsembleModel: