diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py
index a558e1fb0..c89628494 100644
--- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py
+++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py
@@ -1,30 +1,28 @@
 from elasticsearch import Elasticsearch, BadRequestError
 from typing import Optional
 import ssl
-from elasticsearch.helpers import bulk,scan
+from elasticsearch.helpers import bulk, scan
 from open_webui.retrieval.vector.main import VectorItem, SearchResult, GetResult
 from open_webui.config import (
     ELASTICSEARCH_URL,
-    ELASTICSEARCH_CA_CERTS, 
+    ELASTICSEARCH_CA_CERTS,
     ELASTICSEARCH_API_KEY,
     ELASTICSEARCH_USERNAME,
-    ELASTICSEARCH_PASSWORD, 
+    ELASTICSEARCH_PASSWORD,
     ELASTICSEARCH_CLOUD_ID,
     ELASTICSEARCH_INDEX_PREFIX,
     SSL_ASSERT_FINGERPRINT,
-
 )
-
-
 
 
 class ElasticsearchClient:
     """
     Important:
-    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
-    an index for each file but store it as a text field, while seperating to different index
-    baesd on the embedding length.
+    in order to reduce the number of indexes and since the embedding vector length is fixed, we avoid creating
+    an index for each file but store it as a text field, while separating into different indices
+    based on the embedding length.
     """
+
     def __init__(self):
         self.index_prefix = ELASTICSEARCH_INDEX_PREFIX
         self.client = Elasticsearch(
@@ -32,15 +30,19 @@ class ElasticsearchClient:
             ca_certs=ELASTICSEARCH_CA_CERTS,
             api_key=ELASTICSEARCH_API_KEY,
             cloud_id=ELASTICSEARCH_CLOUD_ID,
-            basic_auth=(ELASTICSEARCH_USERNAME,ELASTICSEARCH_PASSWORD) if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD else None,
-            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT
-
+            basic_auth=(
+                (ELASTICSEARCH_USERNAME, ELASTICSEARCH_PASSWORD)
+                if ELASTICSEARCH_USERNAME and ELASTICSEARCH_PASSWORD
+                else None
+            ),
+            ssl_assert_fingerprint=SSL_ASSERT_FINGERPRINT,
         )
-    #Status: works
-    def _get_index_name(self,dimension:int)->str:
+
+    # Status: works
+    def _get_index_name(self, dimension: int) -> str:
         return f"{self.index_prefix}_d{str(dimension)}"
-
-    #Status: works
+
+    # Status: works
     def _scan_result_to_get_result(self, result) -> GetResult:
         if not result:
             return None
@@ -55,7 +57,7 @@ class ElasticsearchClient:
 
         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
-    #Status: works
+    # Status: works
     def _result_to_get_result(self, result) -> GetResult:
         if not result["hits"]["hits"]:
             return None
@@ -70,7 +72,7 @@ class ElasticsearchClient:
 
         return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
 
-    #Status: works
+    # Status: works
     def _result_to_search_result(self, result) -> SearchResult:
         ids = []
         distances = []
@@ -84,19 +86,21 @@ class ElasticsearchClient:
             metadatas.append(hit["_source"].get("metadata"))
 
         return SearchResult(
-            ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
+            ids=[ids],
+            distances=[distances],
+            documents=[documents],
+            metadatas=[metadatas],
        )
-    #Status: works
+
+    # Status: works
     def _create_index(self, dimension: int):
         body = {
             "mappings": {
                 "dynamic_templates": [
                     {
-                        "strings": {
-                            "match_mapping_type": "string",
-                            "mapping": {
-                                "type": "keyword"
-                            }
+                        "strings": {
+                            "match_mapping_type": "string",
+                            "mapping": {"type": "keyword"},
                         }
                     }
                 ],
@@ -111,68 +115,52 @@ class ElasticsearchClient:
                 "properties": {
                     "collection": {"type": "keyword"},
                     "id": {"type": "keyword"},
                     "vector": {
                         "type": "dense_vector",
                         "dims": dimension,  # Adjust based on your vector dimensions
                         "index": True,
                         "similarity": "cosine",
                     },
                     "text": {"type": "text"},
                     "metadata": {"type": "object"},
-                }
+                },
             }
         }
         self.client.indices.create(index=self._get_index_name(dimension), body=body)
-    #Status: works
+
+    # Status: works
     def _create_batches(self, items: list[VectorItem], batch_size=100):
         for i in range(0, len(items), batch_size):
-            yield items[i : min(i + batch_size,len(items))]
+            yield items[i : min(i + batch_size, len(items))]
 
-    #Status: works
-    def has_collection(self,collection_name) -> bool:
+    # Status: works
+    def has_collection(self, collection_name) -> bool:
         query_body = {"query": {"bool": {"filter": []}}}
-        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
+        query_body["query"]["bool"]["filter"].append(
+            {"term": {"collection": collection_name}}
+        )
         try:
-            result = self.client.count(
-                index=f"{self.index_prefix}*",
-                body=query_body
-            )
-
-            return result.body["count"]>0
+            result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
+
+            return result.body["count"] > 0
         except Exception as e:
             return None
-
-
 
     def delete_collection(self, collection_name: str):
-        query = {
-            "query": {
-                "term": {"collection": collection_name}
-            }
-        }
+        query = {"query": {"term": {"collection": collection_name}}}
         self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
-    #Status: works
+
+    # Status: works
     def search(
         self, collection_name: str, vectors: list[list[float]], limit: int
     ) -> Optional[SearchResult]:
         query = {
             "size": limit,
-            "_source": [
-                "text",
-                "metadata"
-            ],
+            "_source": ["text", "metadata"],
             "query": {
                 "script_score": {
                     "query": {
-                        "bool": {
-                            "filter": [
-                                {
-                                    "term": {
-                                        "collection": collection_name
-                                    }
-                                }
-                            ]
-                        }
+                        "bool": {"filter": [{"term": {"collection": collection_name}}]}
                     },
                     "script": {
                         "source": "cosineSimilarity(params.vector, 'vector') + 1.0",
                         "params": {
                             "vector": vectors[0]
-                        }, # Assuming single query vector
+                        },  # Assuming single query vector
                     },
                 }
             },
@@ -183,7 +171,8 @@ class ElasticsearchClient:
         )
 
         return self._result_to_search_result(result)
-    #Status: only tested halfwat
+
+    # Status: only tested halfway
     def query(
         self, collection_name: str, filter: dict, limit: Optional[int] = None
     ) -> Optional[GetResult]:
@@ -197,7 +186,9 @@ class ElasticsearchClient:
         for field, value in filter.items():
             query_body["query"]["bool"]["filter"].append({"term": {field: value}})
 
-        query_body["query"]["bool"]["filter"].append({"term": {"collection": collection_name}})
+        query_body["query"]["bool"]["filter"].append(
+            {"term": {"collection": collection_name}}
+        )
 
         size = limit if limit else 10
         try:
@@ -206,59 +197,53 @@ class ElasticsearchClient:
                 body=query_body,
                 size=size,
             )
-
+
             return self._result_to_get_result(result)
         except Exception as e:
             return None
 
-    #Status: works
-    def _has_index(self,dimension:int):
-        return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
+    # Status: works
+    def _has_index(self, dimension: int):
+        return self.client.indices.exists(
+            index=self._get_index_name(dimension=dimension)
+        )
 
     def get_or_create_index(self, dimension: int):
         if not self._has_index(dimension=dimension):
             self._create_index(dimension=dimension)
-    #Status: works
+
+    # Status: works
     def get(self, collection_name: str) -> Optional[GetResult]:
         # Get all the items in the collection.
         query = {
-            "query": {
-                "bool": {
-                    "filter": [
-                        {
-                            "term": {
-                                "collection": collection_name
-                            }
-                        }
-                    ]
-                }
-            }, "_source": ["text", "metadata"]}
+            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
+            "_source": ["text", "metadata"],
+        }
 
         results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
-
+
         return self._scan_result_to_get_result(results)
 
-    #Status: works
+    # Status: works
     def insert(self, collection_name: str, items: list[VectorItem]):
         if not self._has_index(dimension=len(items[0]["vector"])):
             self._create_index(dimension=len(items[0]["vector"]))
-
         for batch in self._create_batches(items):
             actions = [
-                {
-                    "_index":self._get_index_name(dimension=len(items[0]["vector"])),
-                    "_id": item["id"],
-                    "_source": {
-                        "collection": collection_name,
-                        "vector": item["vector"],
-                        "text": item["text"],
-                        "metadata": item["metadata"],
-                    },
-                }
+                {
+                    "_index": self._get_index_name(dimension=len(items[0]["vector"])),
+                    "_id": item["id"],
+                    "_source": {
+                        "collection": collection_name,
+                        "vector": item["vector"],
+                        "text": item["text"],
+                        "metadata": item["metadata"],
+                    },
+                }
                 for item in batch
             ]
-            bulk(self.client,actions)
+            bulk(self.client, actions)
 
     # Upsert documents using the update API with doc_as_upsert=True.
     def upsert(self, collection_name: str, items: list[VectorItem]):
@@ -280,8 +265,7 @@ class ElasticsearchClient:
                 }
                 for item in batch
             ]
-            bulk(self.client,actions)
-
+            bulk(self.client, actions)
 
     # Delete specific documents from a collection by filtering on both collection and document IDs.
     def delete(
@@ -292,21 +276,16 @@ class ElasticsearchClient:
         self,
         collection_name: str,
         ids: Optional[list[str]] = None,
         filter: Optional[dict] = None,
     ):
 
         query = {
-            "query": {
-                "bool": {
-                    "filter": [
-                        {"term": {"collection": collection_name}}
-                    ]
-                }
-            }
+            "query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
         }
 
-        #logic based on chromaDB
+        # logic based on chromaDB
        if ids:
             query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
         elif filter:
             for field, value in filter.items():
-                query["query"]["bool"]["filter"].append({"term": {f"metadata.{field}": value}})
-
+                query["query"]["bool"]["filter"].append(
+                    {"term": {f"metadata.{field}": value}}
+                )
 
         self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
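
For review, a minimal usage sketch of the client after this formatting pass. It is not part of the patch: "demo-collection", the 4-dimensional vectors, and the metadata values are hypothetical; it assumes a reachable Elasticsearch node and the ELASTICSEARCH_* settings from open_webui.config; and it passes plain dicts for items, matching the item["..."] subscript access that insert() itself uses.

# Review aid (not part of the patch): round-trip through the
# collection-as-field layout described in the class docstring.
from open_webui.retrieval.vector.dbs.elasticsearch import ElasticsearchClient

client = ElasticsearchClient()

# Both items land in the same per-dimension index (f"{prefix}_d4") because
# their vectors have the same length; the "collection" keyword field keeps
# the logical collections separate inside that shared index.
client.insert(
    "demo-collection",
    [
        {"id": "a", "vector": [0.1, 0.2, 0.3, 0.4], "text": "first chunk", "metadata": {"page": 1}},
        {"id": "b", "vector": [0.4, 0.3, 0.2, 0.1], "text": "second chunk", "metadata": {"page": 2}},
    ],
)

# search() scores with cosineSimilarity(...) + 1.0, so the "distances" it
# returns are similarity scores in [0, 2] (higher is closer), not distances.
result = client.search("demo-collection", vectors=[[0.1, 0.2, 0.3, 0.4]], limit=2)
if result:
    print(result.ids[0], result.distances[0])

# Clean up the hypothetical documents by id within the collection.
client.delete("demo-collection", ids=["a", "b"])

Note that SearchResult and GetResult pack a single query's results as lists of lists (ids=[ids], and so on, as in _result_to_search_result above), so the [0] indexing in the sketch selects the result set for the one query vector.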