mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-11 05:36:03 +02:00
DAN-2 Backend Support for Filters (#8)
Additionally added an __init__.py for mypy issue
This commit is contained in:
0
backend/danswer/__init__.py
Normal file
0
backend/danswer/__init__.py
Normal file
@@ -10,6 +10,7 @@ APP_PORT = 8080
|
|||||||
#####
|
#####
|
||||||
# Vector DB Configs
|
# Vector DB Configs
|
||||||
#####
|
#####
|
||||||
|
DEFAULT_VECTOR_STORE = os.environ.get("VECTOR_DB", "qdrant")
|
||||||
# Url / Key are used to connect to a remote Qdrant instance
|
# Url / Key are used to connect to a remote Qdrant instance
|
||||||
QDRANT_URL = os.environ.get("QDRANT_URL", "")
|
QDRANT_URL = os.environ.get("QDRANT_URL", "")
|
||||||
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
|
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
|
||||||
|
@@ -1,12 +1,25 @@
|
|||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
|
from danswer.configs.app_configs import DEFAULT_VECTOR_STORE
|
||||||
from danswer.datastores.interfaces import Datastore
|
from danswer.datastores.interfaces import Datastore
|
||||||
from danswer.datastores.qdrant.store import QdrantDatastore
|
from danswer.datastores.qdrant.store import QdrantDatastore
|
||||||
|
|
||||||
|
|
||||||
def get_selected_datastore_cls() -> Type[Datastore]:
|
def get_selected_datastore_cls(
|
||||||
|
vector_db_type: str = DEFAULT_VECTOR_STORE,
|
||||||
|
) -> Type[Datastore]:
|
||||||
"""Returns the selected Datastore cls. Only one datastore
|
"""Returns the selected Datastore cls. Only one datastore
|
||||||
should be selected for a specific deployment."""
|
should be selected for a specific deployment."""
|
||||||
# TOOD: when more datastores are added, look at env variable to
|
if vector_db_type == "qdrant":
|
||||||
# figure out which one should be returned
|
return QdrantDatastore
|
||||||
return QdrantDatastore
|
else:
|
||||||
|
raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def create_datastore(
|
||||||
|
collection: str, vector_db_type: str = DEFAULT_VECTOR_STORE
|
||||||
|
) -> Datastore:
|
||||||
|
if vector_db_type == "qdrant":
|
||||||
|
return QdrantDatastore(collection=collection)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
|
||||||
|
@@ -3,6 +3,8 @@ import abc
|
|||||||
from danswer.chunking.models import EmbeddedIndexChunk
|
from danswer.chunking.models import EmbeddedIndexChunk
|
||||||
from danswer.chunking.models import InferenceChunk
|
from danswer.chunking.models import InferenceChunk
|
||||||
|
|
||||||
|
DatastoreFilter = dict[str, str | list[str] | None]
|
||||||
|
|
||||||
|
|
||||||
class Datastore:
|
class Datastore:
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -10,5 +12,7 @@ class Datastore:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
|
def semantic_retrieval(
|
||||||
|
self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
|
||||||
|
) -> list[InferenceChunk]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@@ -2,12 +2,17 @@ from danswer.chunking.models import EmbeddedIndexChunk
|
|||||||
from danswer.chunking.models import InferenceChunk
|
from danswer.chunking.models import InferenceChunk
|
||||||
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
||||||
from danswer.datastores.interfaces import Datastore
|
from danswer.datastores.interfaces import Datastore
|
||||||
|
from danswer.datastores.interfaces import DatastoreFilter
|
||||||
from danswer.datastores.qdrant.indexing import index_chunks
|
from danswer.datastores.qdrant.indexing import index_chunks
|
||||||
from danswer.embedding.biencoder import get_default_model
|
from danswer.embedding.biencoder import get_default_model
|
||||||
from danswer.utils.clients import get_qdrant_client
|
from danswer.utils.clients import get_qdrant_client
|
||||||
from danswer.utils.logging import setup_logger
|
from danswer.utils.logging import setup_logger
|
||||||
|
from danswer.utils.timing import build_timing_wrapper
|
||||||
|
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||||
|
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||||
from qdrant_client.http.models import FieldCondition
|
from qdrant_client.http.models import FieldCondition
|
||||||
from qdrant_client.http.models import Filter
|
from qdrant_client.http.models import Filter
|
||||||
|
from qdrant_client.http.models import MatchAny
|
||||||
from qdrant_client.http.models import MatchValue
|
from qdrant_client.http.models import MatchValue
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@@ -23,25 +28,63 @@ class QdrantDatastore(Datastore):
|
|||||||
chunks=chunks, collection=self.collection, client=self.client
|
chunks=chunks, collection=self.collection, client=self.client
|
||||||
)
|
)
|
||||||
|
|
||||||
def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
|
@build_timing_wrapper()
|
||||||
|
def semantic_retrieval(
|
||||||
|
self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
|
||||||
|
) -> list[InferenceChunk]:
|
||||||
query_embedding = get_default_model().encode(
|
query_embedding = get_default_model().encode(
|
||||||
query
|
query
|
||||||
) # TODO: make this part of the embedder interface
|
) # TODO: make this part of the embedder interface
|
||||||
hits = self.client.search(
|
if not isinstance(query_embedding, list):
|
||||||
collection_name=self.collection,
|
query_embedding = query_embedding.tolist()
|
||||||
query_vector=query_embedding
|
|
||||||
if isinstance(query_embedding, list)
|
hits = []
|
||||||
else query_embedding.tolist(),
|
filter_conditions = []
|
||||||
query_filter=None,
|
try:
|
||||||
limit=num_to_retrieve,
|
if filters:
|
||||||
)
|
for filter_dict in filters:
|
||||||
|
valid_filters = {
|
||||||
|
key: value
|
||||||
|
for key, value in filter_dict.items()
|
||||||
|
if value is not None
|
||||||
|
}
|
||||||
|
for filter_key, filter_val in valid_filters.items():
|
||||||
|
if isinstance(filter_val, str):
|
||||||
|
filter_conditions.append(
|
||||||
|
FieldCondition(
|
||||||
|
key=filter_key,
|
||||||
|
match=MatchValue(value=filter_val),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(filter_val, list):
|
||||||
|
filter_conditions.append(
|
||||||
|
FieldCondition(
|
||||||
|
key=filter_key,
|
||||||
|
match=MatchAny(any=filter_val),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid filters provided")
|
||||||
|
|
||||||
|
hits = self.client.search(
|
||||||
|
collection_name=self.collection,
|
||||||
|
query_vector=query_embedding,
|
||||||
|
query_filter=Filter(must=list(filter_conditions)),
|
||||||
|
limit=num_to_retrieve,
|
||||||
|
)
|
||||||
|
except ResponseHandlingException as e:
|
||||||
|
logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
|
||||||
|
except UnexpectedResponse as e:
|
||||||
|
logger.exception(
|
||||||
|
f'Qdrant querying failed due to: "{e}", has ingestion been run?'
|
||||||
|
)
|
||||||
return [InferenceChunk.from_dict(hit.payload) for hit in hits]
|
return [InferenceChunk.from_dict(hit.payload) for hit in hits]
|
||||||
|
|
||||||
def get_from_id(self, object_id: str) -> InferenceChunk | None:
|
def get_from_id(self, object_id: str) -> InferenceChunk | None:
|
||||||
matches, _ = self.client.scroll(
|
matches, _ = self.client.scroll(
|
||||||
collection_name=self.collection,
|
collection_name=self.collection,
|
||||||
scroll_filter=Filter(
|
scroll_filter=Filter(
|
||||||
should=[FieldCondition(key="id", match=MatchValue(value=object_id))]
|
must=[FieldCondition(key="id", match=MatchValue(value=object_id))]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not matches:
|
if not matches:
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
@@ -10,11 +11,10 @@ from danswer.configs.model_configs import CROSS_ENCODER_MODEL
|
|||||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||||
from danswer.configs.model_configs import MODEL_CACHE_FOLDER
|
from danswer.configs.model_configs import MODEL_CACHE_FOLDER
|
||||||
from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
|
from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
|
||||||
from danswer.utils.clients import get_qdrant_client
|
from danswer.datastores.interfaces import Datastore
|
||||||
|
from danswer.datastores.interfaces import DatastoreFilter
|
||||||
from danswer.utils.logging import setup_logger
|
from danswer.utils.logging import setup_logger
|
||||||
from danswer.utils.timing import build_timing_wrapper
|
from danswer.utils.timing import build_timing_wrapper
|
||||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
|
||||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
|
||||||
from sentence_transformers import CrossEncoder # type: ignore
|
from sentence_transformers import CrossEncoder # type: ignore
|
||||||
from sentence_transformers import SentenceTransformer # type: ignore
|
from sentence_transformers import SentenceTransformer # type: ignore
|
||||||
|
|
||||||
@@ -32,50 +32,13 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
|
|||||||
cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
|
cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
|
||||||
|
|
||||||
|
|
||||||
@build_timing_wrapper()
|
|
||||||
def semantic_retrival(
|
|
||||||
qdrant_collection: str,
|
|
||||||
query: str,
|
|
||||||
num_hits: int = NUM_RETURNED_HITS,
|
|
||||||
use_openai: bool = False,
|
|
||||||
) -> List[InferenceChunk]:
|
|
||||||
if use_openai:
|
|
||||||
query_embedding = openai.Embedding.create(
|
|
||||||
input=query, model="text-embedding-ada-002"
|
|
||||||
)["data"][0]["embedding"]
|
|
||||||
else:
|
|
||||||
query_embedding = embedding_model.encode(query)
|
|
||||||
try:
|
|
||||||
hits = get_qdrant_client().search(
|
|
||||||
collection_name=qdrant_collection,
|
|
||||||
query_vector=query_embedding
|
|
||||||
if isinstance(query_embedding, list)
|
|
||||||
else query_embedding.tolist(),
|
|
||||||
query_filter=None,
|
|
||||||
limit=num_hits,
|
|
||||||
)
|
|
||||||
except ResponseHandlingException as e:
|
|
||||||
logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
|
|
||||||
except UnexpectedResponse as e:
|
|
||||||
logger.exception(
|
|
||||||
f'Qdrant querying failed due to: "{e}", has ingestion been run?'
|
|
||||||
)
|
|
||||||
|
|
||||||
retrieved_chunks = []
|
|
||||||
for hit in hits:
|
|
||||||
payload = hit.payload
|
|
||||||
retrieved_chunks.append(InferenceChunk.from_dict(payload))
|
|
||||||
|
|
||||||
return retrieved_chunks
|
|
||||||
|
|
||||||
|
|
||||||
@build_timing_wrapper()
|
@build_timing_wrapper()
|
||||||
def semantic_reranking(
|
def semantic_reranking(
|
||||||
query: str,
|
query: str,
|
||||||
chunks: List[InferenceChunk],
|
chunks: List[InferenceChunk],
|
||||||
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
|
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
|
||||||
) -> List[InferenceChunk]:
|
) -> List[InferenceChunk]:
|
||||||
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])
|
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
|
||||||
scored_results = list(zip(sim_scores, chunks))
|
scored_results = list(zip(sim_scores, chunks))
|
||||||
scored_results.sort(key=lambda x: x[0], reverse=True)
|
scored_results.sort(key=lambda x: x[0], reverse=True)
|
||||||
ranked_sim_scores, ranked_chunks = zip(*scored_results)
|
ranked_sim_scores, ranked_chunks = zip(*scored_results)
|
||||||
@@ -88,11 +51,18 @@ def semantic_reranking(
|
|||||||
|
|
||||||
|
|
||||||
def semantic_search(
|
def semantic_search(
|
||||||
qdrant_collection: str,
|
|
||||||
query: str,
|
query: str,
|
||||||
|
filters: list[DatastoreFilter] | None,
|
||||||
|
datastore: Datastore,
|
||||||
num_hits: int = NUM_RETURNED_HITS,
|
num_hits: int = NUM_RETURNED_HITS,
|
||||||
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
|
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
|
||||||
) -> List[InferenceChunk]:
|
) -> List[InferenceChunk] | None:
|
||||||
top_chunks = semantic_retrival(qdrant_collection, query, num_hits)
|
top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
|
||||||
|
if not top_chunks:
|
||||||
|
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
|
||||||
|
logger.warning(
|
||||||
|
f"Semantic search returned no results with filters: {filters_log_msg}"
|
||||||
|
)
|
||||||
|
return None
|
||||||
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
|
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
|
||||||
return ranked_chunks
|
return ranked_chunks
|
||||||
|
@@ -1,13 +1,12 @@
|
|||||||
import time
|
import time
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict
|
|
||||||
from typing import List
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from danswer.configs.app_configs import DEFAULT_PROMPT
|
from danswer.configs.app_configs import DEFAULT_PROMPT
|
||||||
from danswer.configs.app_configs import KEYWORD_MAX_HITS
|
from danswer.configs.app_configs import KEYWORD_MAX_HITS
|
||||||
from danswer.configs.constants import CONTENT
|
from danswer.configs.constants import CONTENT
|
||||||
from danswer.configs.constants import SOURCE_LINKS
|
from danswer.configs.constants import SOURCE_LINKS
|
||||||
|
from danswer.datastores import create_datastore
|
||||||
|
from danswer.datastores.interfaces import DatastoreFilter
|
||||||
from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
|
from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
|
||||||
from danswer.direct_qa.question_answer import answer_question
|
from danswer.direct_qa.question_answer import answer_question
|
||||||
from danswer.direct_qa.question_answer import process_answer
|
from danswer.direct_qa.question_answer import process_answer
|
||||||
@@ -30,15 +29,16 @@ class ServerStatus(BaseModel):
|
|||||||
class QAQuestion(BaseModel):
|
class QAQuestion(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
collection: str
|
collection: str
|
||||||
|
filters: list[DatastoreFilter] | None
|
||||||
|
|
||||||
|
|
||||||
class QAResponse(BaseModel):
|
class QAResponse(BaseModel):
|
||||||
answer: Union[str, None]
|
answer: str | None
|
||||||
quotes: Union[Dict[str, Dict[str, str]], None]
|
quotes: dict[str, dict[str, str]] | None
|
||||||
|
|
||||||
|
|
||||||
class KeywordResponse(BaseModel):
|
class KeywordResponse(BaseModel):
|
||||||
results: Union[List[str], None]
|
results: list[str] | None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=ServerStatus)
|
@router.get("/", response_model=ServerStatus)
|
||||||
@@ -52,16 +52,24 @@ def direct_qa(question: QAQuestion):
|
|||||||
prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
|
prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
|
||||||
query = question.query
|
query = question.query
|
||||||
collection = question.collection
|
collection = question.collection
|
||||||
|
filters = question.filters
|
||||||
|
|
||||||
|
datastore = create_datastore(collection)
|
||||||
|
|
||||||
logger.info(f"Received semantic query: {query}")
|
logger.info(f"Received semantic query: {query}")
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
ranked_chunks = semantic_search(collection, query)
|
start_time = time.time()
|
||||||
|
ranked_chunks = semantic_search(query, filters, datastore)
|
||||||
sem_search_time = time.time()
|
sem_search_time = time.time()
|
||||||
|
|
||||||
|
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
|
||||||
|
|
||||||
|
if not ranked_chunks:
|
||||||
|
return {"answer": None, "quotes": None}
|
||||||
|
|
||||||
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
|
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
|
||||||
top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
|
top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
|
||||||
|
|
||||||
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
|
|
||||||
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
|
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
|
||||||
logger.info(files_log_msg)
|
logger.info(files_log_msg)
|
||||||
|
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -8,35 +9,60 @@ from danswer.configs.constants import SOURCE_TYPE
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
previous_query = None
|
previous_query = None
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-k",
|
||||||
|
"--keyword-search",
|
||||||
|
action="store_true",
|
||||||
|
help="Use keyword search if set, semantic search otherwise",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-s",
|
||||||
|
"--source-types",
|
||||||
|
type=str,
|
||||||
|
help="Comma separated list of source types to filter by",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("query", nargs="*", help="The query to process")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
keyword_search = False
|
user_input = input(
|
||||||
query = input(
|
"\n\nAsk any question:\n"
|
||||||
"\n\nAsk any question:\n - prefix with -k for keyword search\n - input an empty string to "
|
" - prefix with -s to add a filter by source(s)\n"
|
||||||
"rerun last query\n\t"
|
" - input an empty string to rerun last query\n\t"
|
||||||
)
|
)
|
||||||
|
|
||||||
if query.lower() in ["q", "quit", "exit", "exit()"]:
|
if user_input:
|
||||||
break
|
previous_input = user_input
|
||||||
|
|
||||||
if query:
|
|
||||||
previous_query = query
|
|
||||||
else:
|
else:
|
||||||
if not previous_query:
|
if not previous_input:
|
||||||
print("No previous query")
|
print("No previous input")
|
||||||
continue
|
continue
|
||||||
print(f"Re-executing previous question:\n\t{previous_query}")
|
print(f"Re-executing previous question:\n\t{previous_input}")
|
||||||
query = previous_query
|
user_input = previous_input
|
||||||
|
|
||||||
|
args = parser.parse_args(user_input.split())
|
||||||
|
|
||||||
|
keyword_search = args.keyword_search
|
||||||
|
source_types = args.source_types.split(",") if args.source_types else None
|
||||||
|
if source_types and len(source_types) == 1:
|
||||||
|
source_types = source_types[0]
|
||||||
|
query = " ".join(args.query)
|
||||||
|
|
||||||
endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
|
endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
|
||||||
if query.startswith("-k "):
|
if args.keyword_search:
|
||||||
keyword_search = True
|
|
||||||
query = query[2:]
|
|
||||||
endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
|
endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
|
||||||
|
|
||||||
response = requests.post(
|
query_json = {
|
||||||
endpoint, json={"query": query, "collection": QDRANT_DEFAULT_COLLECTION}
|
"query": query,
|
||||||
)
|
"collection": QDRANT_DEFAULT_COLLECTION,
|
||||||
|
"filters": [{SOURCE_TYPE: source_types}],
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(endpoint, json=query_json)
|
||||||
contents = json.loads(response.content)
|
contents = json.loads(response.content)
|
||||||
if keyword_search:
|
if keyword_search:
|
||||||
if contents["results"]:
|
if contents["results"]:
|
||||||
|
Reference in New Issue
Block a user