From 76b7792e693cdf5dfe599c54fbe30bb0f8b1e145 Mon Sep 17 00:00:00 2001
From: Weves
Date: Sun, 4 Aug 2024 14:47:19 -0700
Subject: [PATCH] Harden embedding calls

Pull the model server request into a shared _make_model_server_request
helper that surfaces the server's "detail" message on HTTP errors and
falls back to the raw response body when the error payload is not JSON.
Passage embeddings (only used for indexing) are retried up to 3 times
with a 5 second delay; query embeddings stay single-shot so the realtime
path fails fast. Also split encode() into _encode_api_model and
_encode_local_model, so API-provider requests are now batched instead of
being sent as a single oversized request.
---
 .../search_nlp_models.py | 142 +++++++++++-------
 1 file changed, 86 insertions(+), 56 deletions(-)

diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py
index 08c4eb534..1eef5b6dc 100644
--- a/backend/danswer/natural_language_processing/search_nlp_models.py
+++ b/backend/danswer/natural_language_processing/search_nlp_models.py
@@ -2,6 +2,7 @@ import time
 
 import requests
 from httpx import HTTPError
+from retry import retry
 
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import (
@@ -72,6 +73,86 @@ class EmbeddingModel:
         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
+    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
+        def _make_request() -> EmbedResponse:
+            response = requests.post(
+                self.embed_server_endpoint, json=embed_request.dict()
+            )
+            try:
+                response.raise_for_status()
+            except requests.HTTPError as e:
+                try:
+                    error_detail = response.json().get("detail", str(e))
+                except Exception:
+                    error_detail = response.text
+                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
+            except requests.RequestException as e:
+                raise HTTPError(f"Request failed: {str(e)}") from e
+
+            return EmbedResponse(**response.json())
+
+        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
+        if embed_request.text_type == EmbedTextType.PASSAGE:
+            return retry(tries=3, delay=5)(_make_request)()
+        else:
+            return _make_request()
+
+    def _encode_api_model(
+        self, texts: list[str], text_type: EmbedTextType, batch_size: int
+    ) -> list[Embedding]:
+        if not self.provider_type:
+            raise ValueError("Provider type is not set for API embedding")
+
+        embeddings: list[Embedding] = []
+
+        text_batches = batch_list(texts, batch_size)
+        for idx, text_batch in enumerate(text_batches, start=1):
+            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
+            embed_request = EmbedRequest(
+                model_name=self.model_name,
+                texts=text_batch,
+                max_context_length=self.max_seq_length,
+                normalize_embeddings=self.normalize,
+                api_key=self.api_key,
+                provider_type=self.provider_type,
+                text_type=text_type,
+                manual_query_prefix=self.query_prefix,
+                manual_passage_prefix=self.passage_prefix,
+            )
+            response = self._make_model_server_request(embed_request)
+            embeddings.extend(response.embeddings)
+
+        return embeddings
+
+    def _encode_local_model(
+        self,
+        texts: list[str],
+        text_type: EmbedTextType,
+        batch_size: int,
+    ) -> list[Embedding]:
+        text_batches = batch_list(texts, batch_size)
+        embeddings: list[Embedding] = []
+        logger.debug(
+            f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
+        )
+        for idx, text_batch in enumerate(text_batches, start=1):
+            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
+            embed_request = EmbedRequest(
+                model_name=self.model_name,
+                texts=text_batch,
+                max_context_length=self.max_seq_length,
+                normalize_embeddings=self.normalize,
+                api_key=self.api_key,
+                provider_type=self.provider_type,
+                text_type=text_type,
+                manual_query_prefix=self.query_prefix,
+                manual_passage_prefix=self.passage_prefix,
+            )
+
+            response = self._make_model_server_request(embed_request)
+            embeddings.extend(response.embeddings)
+        return embeddings
+
     def encode(
         self,
         texts: list[str],
@@ -99,65 +180,14 @@ class EmbeddingModel:
         ]
 
         if self.provider_type:
-            text_batches = batch_list(texts, api_embedding_batch_size)
-            embed_request = EmbedRequest(
-                model_name=self.model_name,
-                texts=texts,
-                max_context_length=self.max_seq_length,
-                normalize_embeddings=self.normalize,
-                api_key=self.api_key,
-                provider_type=self.provider_type,
-                text_type=text_type,
-                manual_query_prefix=self.query_prefix,
-                manual_passage_prefix=self.passage_prefix,
+            return self._encode_api_model(
+                texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
             )
-            response = requests.post(
-                self.embed_server_endpoint, json=embed_request.dict()
-            )
-            try:
-                response.raise_for_status()
-            except requests.HTTPError as e:
-                error_detail = response.json().get("detail", str(e))
-                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
-            except requests.RequestException as e:
-                raise HTTPError(f"Request failed: {str(e)}") from e
-            return EmbedResponse(**response.json()).embeddings
-
-        # Batching for local embedding
-        text_batches = batch_list(texts, local_embedding_batch_size)
-        embeddings: list[Embedding] = []
-        logger.debug(
-            f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
+
+        # if no provider, use local model
+        return self._encode_local_model(
+            texts=texts, text_type=text_type, batch_size=local_embedding_batch_size
         )
-        for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
-            embed_request = EmbedRequest(
-                model_name=self.model_name,
-                texts=text_batch,
-                max_context_length=self.max_seq_length,
-                normalize_embeddings=self.normalize,
-                api_key=self.api_key,
-                provider_type=self.provider_type,
-                text_type=text_type,
-                manual_query_prefix=self.query_prefix,
-                manual_passage_prefix=self.passage_prefix,
-            )
-
-            response = requests.post(
-                self.embed_server_endpoint, json=embed_request.dict()
-            )
-            try:
-                response.raise_for_status()
-            except requests.HTTPError as e:
-                error_detail = response.json().get("detail", str(e))
-                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
-            except requests.RequestException as e:
-                raise HTTPError(f"Request failed: {str(e)}") from e
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right
-            # value for the set loss
-            embeddings.extend(EmbedResponse(**response.json()).embeddings)
-
-        return embeddings
 
 
 class CrossEncoderEnsembleModel:
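
Note for reviewers: the conditional retry in _make_model_server_request
leans on the fact that the retry package's retry(...) is a decorator
factory, so it can be applied at call time rather than at definition
time. Below is a minimal standalone sketch of that pattern, assuming
the retry package from PyPI; the ping_model_server helper and the URL
are hypothetical, for illustration only.

    import requests
    from retry import retry

    def ping_model_server(url: str) -> dict:
        # Raises requests.HTTPError on 4xx/5xx, mirroring the hardened code path.
        response = requests.get(url)
        response.raise_for_status()
        return response.json()

    # Realtime (query) path: call directly and fail fast, no retries.
    result = ping_model_server("http://localhost:9000/health")

    # Indexing (passage) path: wrap the same callable on the fly.
    # retry(tries=3, delay=5) returns a decorator, so this makes up to
    # 3 attempts with a 5 second pause between attempts before raising.
    wrapped = retry(tries=3, delay=5)(ping_model_server)
    result = wrapped("http://localhost:9000/health")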