mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 11:58:34 +02:00
DAN-2 Backend Support for Filters (#8)
Additionally added an __init__.py to resolve a mypy issue
This commit is contained in:
parent
02a6677e21
commit
e7b901f292
0
backend/danswer/__init__.py
Normal file
0
backend/danswer/__init__.py
Normal file
@ -10,6 +10,7 @@ APP_PORT = 8080
|
||||
#####
|
||||
# Vector DB Configs
|
||||
#####
|
||||
DEFAULT_VECTOR_STORE = os.environ.get("VECTOR_DB", "qdrant")
|
||||
# Url / Key are used to connect to a remote Qdrant instance
|
||||
QDRANT_URL = os.environ.get("QDRANT_URL", "")
|
||||
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
|
||||
|
@ -1,12 +1,25 @@
|
||||
from typing import Type
|
||||
|
||||
from danswer.configs.app_configs import DEFAULT_VECTOR_STORE
|
||||
from danswer.datastores.interfaces import Datastore
|
||||
from danswer.datastores.qdrant.store import QdrantDatastore
|
||||
|
||||
|
||||
def get_selected_datastore_cls() -> Type[Datastore]:
|
||||
def get_selected_datastore_cls(
|
||||
vector_db_type: str = DEFAULT_VECTOR_STORE,
|
||||
) -> Type[Datastore]:
|
||||
"""Returns the selected Datastore cls. Only one datastore
|
||||
should be selected for a specific deployment."""
|
||||
# TODO: when more datastores are added, look at env variable to
|
||||
# figure out which one should be returned
|
||||
return QdrantDatastore
|
||||
if vector_db_type == "qdrant":
|
||||
return QdrantDatastore
|
||||
else:
|
||||
raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
|
||||
|
||||
|
||||
def create_datastore(
|
||||
collection: str, vector_db_type: str = DEFAULT_VECTOR_STORE
|
||||
) -> Datastore:
|
||||
if vector_db_type == "qdrant":
|
||||
return QdrantDatastore(collection=collection)
|
||||
else:
|
||||
raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
|
||||
|
@ -3,6 +3,8 @@ import abc
|
||||
from danswer.chunking.models import EmbeddedIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
|
||||
DatastoreFilter = dict[str, str | list[str] | None]
|
||||
|
||||
|
||||
class Datastore:
|
||||
@abc.abstractmethod
|
||||
@ -10,5 +12,7 @@ class Datastore:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
|
||||
def semantic_retrieval(
|
||||
self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
|
||||
) -> list[InferenceChunk]:
|
||||
raise NotImplementedError
|
||||
|
@ -2,12 +2,17 @@ from danswer.chunking.models import EmbeddedIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
||||
from danswer.datastores.interfaces import Datastore
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from danswer.datastores.qdrant.indexing import index_chunks
|
||||
from danswer.embedding.biencoder import get_default_model
|
||||
from danswer.utils.clients import get_qdrant_client
|
||||
from danswer.utils.logging import setup_logger
|
||||
from danswer.utils.timing import build_timing_wrapper
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from qdrant_client.http.models import FieldCondition
|
||||
from qdrant_client.http.models import Filter
|
||||
from qdrant_client.http.models import MatchAny
|
||||
from qdrant_client.http.models import MatchValue
|
||||
|
||||
logger = setup_logger()
|
||||
@ -23,25 +28,63 @@ class QdrantDatastore(Datastore):
|
||||
chunks=chunks, collection=self.collection, client=self.client
|
||||
)
|
||||
|
||||
def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
|
||||
@build_timing_wrapper()
|
||||
def semantic_retrieval(
|
||||
self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
|
||||
) -> list[InferenceChunk]:
|
||||
query_embedding = get_default_model().encode(
|
||||
query
|
||||
) # TODO: make this part of the embedder interface
|
||||
hits = self.client.search(
|
||||
collection_name=self.collection,
|
||||
query_vector=query_embedding
|
||||
if isinstance(query_embedding, list)
|
||||
else query_embedding.tolist(),
|
||||
query_filter=None,
|
||||
limit=num_to_retrieve,
|
||||
)
|
||||
if not isinstance(query_embedding, list):
|
||||
query_embedding = query_embedding.tolist()
|
||||
|
||||
hits = []
|
||||
filter_conditions = []
|
||||
try:
|
||||
if filters:
|
||||
for filter_dict in filters:
|
||||
valid_filters = {
|
||||
key: value
|
||||
for key, value in filter_dict.items()
|
||||
if value is not None
|
||||
}
|
||||
for filter_key, filter_val in valid_filters.items():
|
||||
if isinstance(filter_val, str):
|
||||
filter_conditions.append(
|
||||
FieldCondition(
|
||||
key=filter_key,
|
||||
match=MatchValue(value=filter_val),
|
||||
)
|
||||
)
|
||||
elif isinstance(filter_val, list):
|
||||
filter_conditions.append(
|
||||
FieldCondition(
|
||||
key=filter_key,
|
||||
match=MatchAny(any=filter_val),
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid filters provided")
|
||||
|
||||
hits = self.client.search(
|
||||
collection_name=self.collection,
|
||||
query_vector=query_embedding,
|
||||
query_filter=Filter(must=list(filter_conditions)),
|
||||
limit=num_to_retrieve,
|
||||
)
|
||||
except ResponseHandlingException as e:
|
||||
logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
|
||||
except UnexpectedResponse as e:
|
||||
logger.exception(
|
||||
f'Qdrant querying failed due to: "{e}", has ingestion been run?'
|
||||
)
|
||||
return [InferenceChunk.from_dict(hit.payload) for hit in hits]
|
||||
|
||||
def get_from_id(self, object_id: str) -> InferenceChunk | None:
|
||||
matches, _ = self.client.scroll(
|
||||
collection_name=self.collection,
|
||||
scroll_filter=Filter(
|
||||
should=[FieldCondition(key="id", match=MatchValue(value=object_id))]
|
||||
must=[FieldCondition(key="id", match=MatchValue(value=object_id))]
|
||||
),
|
||||
)
|
||||
if not matches:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
@ -10,11 +11,10 @@ from danswer.configs.model_configs import CROSS_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import MODEL_CACHE_FOLDER
|
||||
from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.utils.clients import get_qdrant_client
|
||||
from danswer.datastores.interfaces import Datastore
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from danswer.utils.logging import setup_logger
|
||||
from danswer.utils.timing import build_timing_wrapper
|
||||
from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
from qdrant_client.http.exceptions import UnexpectedResponse
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
@ -32,50 +32,13 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
|
||||
cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
|
||||
|
||||
|
||||
@build_timing_wrapper()
|
||||
def semantic_retrival(
|
||||
qdrant_collection: str,
|
||||
query: str,
|
||||
num_hits: int = NUM_RETURNED_HITS,
|
||||
use_openai: bool = False,
|
||||
) -> List[InferenceChunk]:
|
||||
if use_openai:
|
||||
query_embedding = openai.Embedding.create(
|
||||
input=query, model="text-embedding-ada-002"
|
||||
)["data"][0]["embedding"]
|
||||
else:
|
||||
query_embedding = embedding_model.encode(query)
|
||||
try:
|
||||
hits = get_qdrant_client().search(
|
||||
collection_name=qdrant_collection,
|
||||
query_vector=query_embedding
|
||||
if isinstance(query_embedding, list)
|
||||
else query_embedding.tolist(),
|
||||
query_filter=None,
|
||||
limit=num_hits,
|
||||
)
|
||||
except ResponseHandlingException as e:
|
||||
logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
|
||||
except UnexpectedResponse as e:
|
||||
logger.exception(
|
||||
f'Qdrant querying failed due to: "{e}", has ingestion been run?'
|
||||
)
|
||||
|
||||
retrieved_chunks = []
|
||||
for hit in hits:
|
||||
payload = hit.payload
|
||||
retrieved_chunks.append(InferenceChunk.from_dict(payload))
|
||||
|
||||
return retrieved_chunks
|
||||
|
||||
|
||||
@build_timing_wrapper()
|
||||
def semantic_reranking(
|
||||
query: str,
|
||||
chunks: List[InferenceChunk],
|
||||
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
|
||||
) -> List[InferenceChunk]:
|
||||
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])
|
||||
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
|
||||
scored_results = list(zip(sim_scores, chunks))
|
||||
scored_results.sort(key=lambda x: x[0], reverse=True)
|
||||
ranked_sim_scores, ranked_chunks = zip(*scored_results)
|
||||
@ -88,11 +51,18 @@ def semantic_reranking(
|
||||
|
||||
|
||||
def semantic_search(
|
||||
qdrant_collection: str,
|
||||
query: str,
|
||||
filters: list[DatastoreFilter] | None,
|
||||
datastore: Datastore,
|
||||
num_hits: int = NUM_RETURNED_HITS,
|
||||
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
|
||||
) -> List[InferenceChunk]:
|
||||
top_chunks = semantic_retrival(qdrant_collection, query, num_hits)
|
||||
) -> List[InferenceChunk] | None:
|
||||
top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
|
||||
if not top_chunks:
|
||||
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
|
||||
logger.warning(
|
||||
f"Semantic search returned no results with filters: {filters_log_msg}"
|
||||
)
|
||||
return None
|
||||
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
|
||||
return ranked_chunks
|
||||
|
@ -1,13 +1,12 @@
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Union
|
||||
|
||||
from danswer.configs.app_configs import DEFAULT_PROMPT
|
||||
from danswer.configs.app_configs import KEYWORD_MAX_HITS
|
||||
from danswer.configs.constants import CONTENT
|
||||
from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.datastores import create_datastore
|
||||
from danswer.datastores.interfaces import DatastoreFilter
|
||||
from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
|
||||
from danswer.direct_qa.question_answer import answer_question
|
||||
from danswer.direct_qa.question_answer import process_answer
|
||||
@ -30,15 +29,16 @@ class ServerStatus(BaseModel):
|
||||
class QAQuestion(BaseModel):
|
||||
query: str
|
||||
collection: str
|
||||
filters: list[DatastoreFilter] | None
|
||||
|
||||
|
||||
class QAResponse(BaseModel):
|
||||
answer: Union[str, None]
|
||||
quotes: Union[Dict[str, Dict[str, str]], None]
|
||||
answer: str | None
|
||||
quotes: dict[str, dict[str, str]] | None
|
||||
|
||||
|
||||
class KeywordResponse(BaseModel):
|
||||
results: Union[List[str], None]
|
||||
results: list[str] | None
|
||||
|
||||
|
||||
@router.get("/", response_model=ServerStatus)
|
||||
@ -52,16 +52,24 @@ def direct_qa(question: QAQuestion):
|
||||
prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
|
||||
query = question.query
|
||||
collection = question.collection
|
||||
filters = question.filters
|
||||
|
||||
datastore = create_datastore(collection)
|
||||
|
||||
logger.info(f"Received semantic query: {query}")
|
||||
start_time = time.time()
|
||||
|
||||
ranked_chunks = semantic_search(collection, query)
|
||||
start_time = time.time()
|
||||
ranked_chunks = semantic_search(query, filters, datastore)
|
||||
sem_search_time = time.time()
|
||||
|
||||
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
|
||||
|
||||
if not ranked_chunks:
|
||||
return {"answer": None, "quotes": None}
|
||||
|
||||
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
|
||||
top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
|
||||
|
||||
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
|
||||
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
|
||||
logger.info(files_log_msg)
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import requests
|
||||
@ -8,35 +9,60 @@ from danswer.configs.constants import SOURCE_TYPE
|
||||
|
||||
if __name__ == "__main__":
|
||||
previous_query = None
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--keyword-search",
|
||||
action="store_true",
|
||||
help="Use keyword search if set, semantic search otherwise",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--source-types",
|
||||
type=str,
|
||||
help="Comma separated list of source types to filter by",
|
||||
)
|
||||
|
||||
parser.add_argument("query", nargs="*", help="The query to process")
|
||||
|
||||
while True:
|
||||
try:
|
||||
keyword_search = False
|
||||
query = input(
|
||||
"\n\nAsk any question:\n - prefix with -k for keyword search\n - input an empty string to "
|
||||
"rerun last query\n\t"
|
||||
user_input = input(
|
||||
"\n\nAsk any question:\n"
|
||||
" - prefix with -s to add a filter by source(s)\n"
|
||||
" - input an empty string to rerun last query\n\t"
|
||||
)
|
||||
|
||||
if query.lower() in ["q", "quit", "exit", "exit()"]:
|
||||
break
|
||||
|
||||
if query:
|
||||
previous_query = query
|
||||
if user_input:
|
||||
previous_input = user_input
|
||||
else:
|
||||
if not previous_query:
|
||||
print("No previous query")
|
||||
if not previous_input:
|
||||
print("No previous input")
|
||||
continue
|
||||
print(f"Re-executing previous question:\n\t{previous_query}")
|
||||
query = previous_query
|
||||
print(f"Re-executing previous question:\n\t{previous_input}")
|
||||
user_input = previous_input
|
||||
|
||||
args = parser.parse_args(user_input.split())
|
||||
|
||||
keyword_search = args.keyword_search
|
||||
source_types = args.source_types.split(",") if args.source_types else None
|
||||
if source_types and len(source_types) == 1:
|
||||
source_types = source_types[0]
|
||||
query = " ".join(args.query)
|
||||
|
||||
endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
|
||||
if query.startswith("-k "):
|
||||
keyword_search = True
|
||||
query = query[2:]
|
||||
if args.keyword_search:
|
||||
endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
|
||||
|
||||
response = requests.post(
|
||||
endpoint, json={"query": query, "collection": QDRANT_DEFAULT_COLLECTION}
|
||||
)
|
||||
query_json = {
|
||||
"query": query,
|
||||
"collection": QDRANT_DEFAULT_COLLECTION,
|
||||
"filters": [{SOURCE_TYPE: source_types}],
|
||||
}
|
||||
|
||||
response = requests.post(endpoint, json=query_json)
|
||||
contents = json.loads(response.content)
|
||||
if keyword_search:
|
||||
if contents["results"]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user