mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-03 18:08:58 +02:00
Harden embedding calls
This commit is contained in:
@@ -2,6 +2,7 @@ import time
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
|
from retry import retry
|
||||||
|
|
||||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||||
from danswer.configs.model_configs import (
|
from danswer.configs.model_configs import (
|
||||||
@@ -72,6 +73,86 @@ class EmbeddingModel:
|
|||||||
model_server_url = build_model_server_url(server_host, server_port)
|
model_server_url = build_model_server_url(server_host, server_port)
|
||||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||||
|
|
||||||
|
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
    """Send one embedding request to the model server and parse the reply.

    Requests whose ``text_type`` is ``EmbedTextType.PASSAGE`` are wrapped in
    ``retry(tries=3, delay=5)``; every other text type is attempted exactly
    once.

    Args:
        embed_request: Payload POSTed as JSON (via ``embed_request.dict()``)
            to ``self.embed_server_endpoint``.

    Returns:
        The ``EmbedResponse`` built from the JSON body of the reply.

    Raises:
        HTTPError: If the POST itself fails (any ``requests.RequestException``)
            or ``raise_for_status()`` reports an error status.
    """

    def _make_request() -> EmbedResponse:
        # The POST is the call that raises transport-level errors, so it must
        # sit inside a try for them to be translated into HTTPError. Guarding
        # only raise_for_status() would leave that translation unreachable.
        try:
            response = requests.post(
                self.embed_server_endpoint, json=embed_request.dict()
            )
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e

        try:
            response.raise_for_status()
        except requests.HTTPError as e:
            # Prefer the "detail" field of the error payload; fall back to the
            # raw body when the payload cannot be read as a JSON object.
            try:
                error_detail = response.json().get("detail", str(e))
            except Exception:
                error_detail = response.text
            raise HTTPError(f"HTTP error occurred: {error_detail}") from e

        return EmbedResponse(**response.json())

    # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
    if embed_request.text_type == EmbedTextType.PASSAGE:
        return retry(tries=3, delay=5)(_make_request)()
    return _make_request()
|
||||||
|
|
||||||
|
def _encode_api_model(
    self, texts: list[str], text_type: EmbedTextType, batch_size: int
) -> list[Embedding]:
    """Embed ``texts`` in batches when an embedding provider is configured.

    Args:
        texts: Strings to embed.
        text_type: Text type forwarded unchanged in every request.
        batch_size: Batch size handed to ``batch_list``.

    Returns:
        The embeddings of all batches, concatenated in batch order.

    Raises:
        ValueError: If ``self.provider_type`` is falsy.
    """
    if not self.provider_type:
        raise ValueError("Provider type is not set for API embedding")

    def _request_for(text_batch: list[str]) -> EmbedRequest:
        # Built per batch so instance attributes are read at request time.
        return EmbedRequest(
            model_name=self.model_name,
            texts=text_batch,
            max_context_length=self.max_seq_length,
            normalize_embeddings=self.normalize,
            api_key=self.api_key,
            provider_type=self.provider_type,
            text_type=text_type,
            manual_query_prefix=self.query_prefix,
            manual_passage_prefix=self.passage_prefix,
        )

    text_batches = batch_list(texts, batch_size)
    embeddings: list[Embedding] = []
    for idx, text_batch in enumerate(text_batches, start=1):
        logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
        batch_response = self._make_model_server_request(_request_for(text_batch))
        embeddings += batch_response.embeddings

    return embeddings
|
||||||
|
|
||||||
|
def _encode_local_model(
    self,
    texts: list[str],
    text_type: EmbedTextType,
    batch_size: int,
) -> list[Embedding]:
    """Embed ``texts`` in batches via the model server.

    Args:
        texts: Strings to embed.
        text_type: Text type forwarded unchanged in every request.
        batch_size: Batch size handed to ``batch_list``.

    Returns:
        The embeddings of all batches, concatenated in batch order.
    """
    embeddings: list[Embedding] = []
    text_batches = batch_list(texts, batch_size)
    logger.debug(
        f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
    )

    for idx, text_batch in enumerate(text_batches, start=1):
        logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
        # One request per batch; each batch's embeddings are appended in order.
        batch_response = self._make_model_server_request(
            EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                max_context_length=self.max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
            )
        )
        embeddings += batch_response.embeddings

    return embeddings
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@@ -99,65 +180,14 @@ class EmbeddingModel:
|
|||||||
]
|
]
|
||||||
|
|
||||||
if self.provider_type:
|
if self.provider_type:
|
||||||
text_batches = batch_list(texts, api_embedding_batch_size)
|
return self._encode_api_model(
|
||||||
embed_request = EmbedRequest(
|
texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
|
||||||
model_name=self.model_name,
|
|
||||||
texts=texts,
|
|
||||||
max_context_length=self.max_seq_length,
|
|
||||||
normalize_embeddings=self.normalize,
|
|
||||||
api_key=self.api_key,
|
|
||||||
provider_type=self.provider_type,
|
|
||||||
text_type=text_type,
|
|
||||||
manual_query_prefix=self.query_prefix,
|
|
||||||
manual_passage_prefix=self.passage_prefix,
|
|
||||||
)
|
)
|
||||||
response = requests.post(
|
|
||||||
self.embed_server_endpoint, json=embed_request.dict()
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response.raise_for_status()
|
|
||||||
except requests.HTTPError as e:
|
|
||||||
error_detail = response.json().get("detail", str(e))
|
|
||||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
|
||||||
except requests.RequestException as e:
|
|
||||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
|
||||||
|
|
||||||
return EmbedResponse(**response.json()).embeddings
|
# if no provider, use local model
|
||||||
|
return self._encode_local_model(
|
||||||
# Batching for local embedding
|
texts=texts, text_type=text_type, batch_size=local_embedding_batch_size
|
||||||
text_batches = batch_list(texts, local_embedding_batch_size)
|
|
||||||
embeddings: list[Embedding] = []
|
|
||||||
logger.debug(
|
|
||||||
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
|
|
||||||
)
|
)
|
||||||
for idx, text_batch in enumerate(text_batches, start=1):
|
|
||||||
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
|
|
||||||
embed_request = EmbedRequest(
|
|
||||||
model_name=self.model_name,
|
|
||||||
texts=text_batch,
|
|
||||||
max_context_length=self.max_seq_length,
|
|
||||||
normalize_embeddings=self.normalize,
|
|
||||||
api_key=self.api_key,
|
|
||||||
provider_type=self.provider_type,
|
|
||||||
text_type=text_type,
|
|
||||||
manual_query_prefix=self.query_prefix,
|
|
||||||
manual_passage_prefix=self.passage_prefix,
|
|
||||||
)
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
self.embed_server_endpoint, json=embed_request.dict()
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response.raise_for_status()
|
|
||||||
except requests.HTTPError as e:
|
|
||||||
error_detail = response.json().get("detail", str(e))
|
|
||||||
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
|
|
||||||
except requests.RequestException as e:
|
|
||||||
raise HTTPError(f"Request failed: {str(e)}") from e
|
|
||||||
# Normalize embeddings is only configured via model_configs.py, be sure to use right
|
|
||||||
# value for the set loss
|
|
||||||
embeddings.extend(EmbedResponse(**response.json()).embeddings)
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderEnsembleModel:
|
class CrossEncoderEnsembleModel:
|
||||||
|
Reference in New Issue
Block a user