Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-04 09:58:32 +02:00).
Commit 76b7792e69 — "Harden embedding calls" (parent: 9d7100a287).
@ -2,6 +2,7 @@ import time
|
||||
|
||||
import requests
|
||||
from httpx import HTTPError
|
||||
from retry import retry
|
||||
|
||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from danswer.configs.model_configs import (
|
||||
@ -72,6 +73,86 @@ class EmbeddingModel:
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
|
||||
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
    """POST the embed request to the model server and parse the reply.

    Raises:
        HTTPError: wrapping either the server-reported ``detail`` message
            or the underlying transport failure.
    """

    def _post_once() -> EmbedResponse:
        # NOTE(review): no timeout is set, so this call can hang
        # indefinitely if the model server stalls — consider adding one.
        raw = requests.post(
            self.embed_server_endpoint, json=embed_request.dict()
        )
        try:
            raw.raise_for_status()
        except requests.HTTPError as e:
            # Prefer the structured error detail when the body is JSON;
            # fall back to the raw text otherwise.
            try:
                detail = raw.json().get("detail", str(e))
            except Exception:
                detail = raw.text
            raise HTTPError(f"HTTP error occurred: {detail}") from e
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e
        return EmbedResponse(**raw.json())

    # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
    if embed_request.text_type != EmbedTextType.PASSAGE:
        return _post_once()
    return retry(tries=3, delay=5)(_post_once)()
|
||||
|
||||
def _encode_api_model(
    self, texts: list[str], text_type: EmbedTextType, batch_size: int
) -> list[Embedding]:
    """Embed texts via a cloud provider, proxied through the model server.

    Batches the texts, issues one model-server request per batch, and
    concatenates the returned embeddings in input order.

    Raises:
        ValueError: if no provider type is configured on this model.
    """
    if not self.provider_type:
        raise ValueError("Provider type is not set for API embedding")

    results: list[Embedding] = []
    batches = batch_list(texts, batch_size)

    for batch_num, batch in enumerate(batches, start=1):
        logger.debug(f"Encoding batch {batch_num} of {len(batches)}")
        request = EmbedRequest(
            model_name=self.model_name,
            texts=batch,
            max_context_length=self.max_seq_length,
            normalize_embeddings=self.normalize,
            api_key=self.api_key,
            provider_type=self.provider_type,
            text_type=text_type,
            manual_query_prefix=self.query_prefix,
            manual_passage_prefix=self.passage_prefix,
        )
        reply = self._make_model_server_request(request)
        results.extend(reply.embeddings)

    return results
|
||||
|
||||
def _encode_local_model(
    self,
    texts: list[str],
    text_type: EmbedTextType,
    batch_size: int,
) -> list[Embedding]:
    """Embed texts with the locally-hosted model behind the model server.

    Mirrors the API path: batch the texts, send one request per batch,
    and return the embeddings concatenated in input order.
    """
    batches = batch_list(texts, batch_size)
    results: list[Embedding] = []

    logger.debug(
        f"Encoding {len(texts)} texts in {len(batches)} batches for local model"
    )
    for batch_num, batch in enumerate(batches, start=1):
        logger.debug(f"Encoding batch {batch_num} of {len(batches)}")
        request = EmbedRequest(
            model_name=self.model_name,
            texts=batch,
            max_context_length=self.max_seq_length,
            normalize_embeddings=self.normalize,
            api_key=self.api_key,
            provider_type=self.provider_type,
            text_type=text_type,
            manual_query_prefix=self.query_prefix,
            manual_passage_prefix=self.passage_prefix,
        )

        reply = self._make_model_server_request(request)
        results.extend(reply.embeddings)

    return results
|
||||
|
||||
def encode(
|
||||
self,
|
||||
texts: list[str],
|
||||
@ -99,65 +180,14 @@ class EmbeddingModel:
|
||||
]
|
||||
|
||||
if self.provider_type:
|
||||
text_batches = batch_list(texts, api_embedding_batch_size)
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=texts,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
return self._encode_api_model(
|
||||
texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
|
||||
)
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
|
||||
return EmbedResponse(**response.json()).embeddings
|
||||
|
||||
# Batching for local embedding
|
||||
text_batches = batch_list(texts, local_embedding_batch_size)
|
||||
embeddings: list[Embedding] = []
|
||||
logger.debug(
|
||||
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
|
||||
# if no provider, use local model
|
||||
return self._encode_local_model(
|
||||
texts=texts, text_type=text_type, batch_size=local_embedding_batch_size
|
||||
)
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=text_batch,
|
||||
max_context_length=self.max_seq_length,
|
||||
normalize_embeddings=self.normalize,
|
||||
api_key=self.api_key,
|
||||
provider_type=self.provider_type,
|
||||
text_type=text_type,
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
||||
except requests.RequestException as e:
|
||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
||||
# Normalize embeddings is only configured via model_configs.py, be sure to use right
|
||||
# value for the set loss
|
||||
embeddings.extend(EmbedResponse(**response.json()).embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CrossEncoderEnsembleModel:
|
||||
|
Loading…
x
Reference in New Issue
Block a user