Merge pull request #12050 from mahenning/fix-db-order

Fix: Normalize all database distances to score in [0, 1] (needs testing for different DBs)
This commit is contained in:
Timothy Jaeryang Baek 2025-03-26 20:56:02 -07:00 committed by GitHub
commit 3514a6c5ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 22 additions and 29 deletions

View File

@ -185,9 +185,7 @@ def merge_get_results(get_results: list[dict]) -> dict:
return result
def merge_and_sort_query_results(
query_results: list[dict], k: int, reverse: bool = False
) -> dict:
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
# Initialize lists to store combined data
combined = dict() # To store documents with unique document hashes
@ -207,28 +205,18 @@ def merge_and_sort_query_results(
continue # if doc is new, no further comparison is needed
# if doc is alredy in, but new distance is better, update
if not reverse and distance < combined[doc_hash][0]:
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
combined[doc_hash] = (distance, document, metadata)
if reverse and distance > combined[doc_hash][0]:
if distance > combined[doc_hash][0]:
combined[doc_hash] = (distance, document, metadata)
combined = list(combined.values())
# Sort the list based on distances
combined.sort(key=lambda x: x[0], reverse=reverse)
combined.sort(key=lambda x: x[0], reverse=True)
# Slice to keep only the top k elements
sorted_distances, sorted_documents, sorted_metadatas = (
zip(*combined[:k]) if combined else ([], [], [])
)
# if chromaDB, the distance is 0 (best) to 2 (worse)
# re-order to -1 (worst) to 1 (best) for relevance score
if not reverse:
sorted_distances = tuple(-dist for dist in sorted_distances)
sorted_distances = tuple(dist + 1 for dist in sorted_distances)
# Create and return the output dictionary
return {
"distances": [list(sorted_distances)],
@ -278,12 +266,7 @@ def query_collection(
else:
pass
if VECTOR_DB == "chroma":
# Chroma uses unconventional cosine similarity, so we don't need to reverse the results
# https://docs.trychroma.com/docs/collections/configure#configuring-chroma-collections
return merge_and_sort_query_results(results, k=k, reverse=False)
else:
return merge_and_sort_query_results(results, k=k, reverse=True)
return merge_and_sort_query_results(results, k=k)
def query_collection_with_hybrid_search(
@ -320,9 +303,7 @@ def query_collection_with_hybrid_search(
raise Exception(
"Hybrid search failed for all collections. Using Non hybrid search as fallback."
)
return merge_and_sort_query_results(results, k=k, reverse=True)
return merge_and_sort_query_results(results, k=k)
def get_embedding_function(
embedding_engine,

View File

@ -75,10 +75,16 @@ class ChromaClient:
n_results=limit,
)
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
# https://docs.trychroma.com/docs/collections/configure cosine equation
distances: list = result["distances"][0]
distances = [2 - dist for dist in distances]
distances = [[dist / 2 for dist in distances]]
return SearchResult(
**{
"ids": result["ids"],
"distances": result["distances"],
"distances": distances,
"documents": result["documents"],
"metadatas": result["metadatas"],
}

View File

@ -64,7 +64,10 @@ class MilvusClient:
for item in match:
_ids.append(item.get("id"))
_distances.append(item.get("distance"))
# normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md
_dist = (item.get("distance") + 1.0) / 2.0
_distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))

View File

@ -120,7 +120,7 @@ class OpenSearchClient:
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
"source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
"params": {
"field": "vector",
"query_value": vectors[0],

View File

@ -278,7 +278,9 @@ class PgvectorClient:
for row in results:
qid = int(row.qid)
ids[qid].append(row.id)
distances[qid].append(row.distance)
# normalize and re-orders pgvec distance from [2, 0] to [0, 1] score range
# https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
distances[qid].append((2.0 - row.distance) / 2.0)
documents[qid].append(row.text)
metadatas[qid].append(row.vmetadata)

View File

@ -99,7 +99,8 @@ class QdrantClient:
ids=get_result.ids,
documents=get_result.documents,
metadatas=get_result.metadatas,
distances=[[point.score for point in query_response.points]],
# qdrant distance is [-1, 1], normalize to [0, 1]
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
)
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):