DAN-2 Backend Support for Filters (#8)

Also adds an __init__.py to fix a mypy issue.
Yuhong Sun 2023-05-01 22:29:09 -07:00 committed by GitHub
parent 02a6677e21
commit e7b901f292
8 changed files with 152 additions and 87 deletions

View File (the new, empty __init__.py mentioned in the commit message)

View File

@@ -10,6 +10,7 @@ APP_PORT = 8080
 #####
 # Vector DB Configs
 #####
+DEFAULT_VECTOR_STORE = os.environ.get("VECTOR_DB", "qdrant")
 # Url / Key are used to connect to a remote Qdrant instance
 QDRANT_URL = os.environ.get("QDRANT_URL", "")
 QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
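
Note: the new DEFAULT_VECTOR_STORE setting makes the vector DB backend selectable per deployment. A minimal sketch of picking it via the environment, with "qdrant" currently the only value the factories below accept:

    import os

    # Assumed deployment-side configuration; any other value makes the
    # datastore factories below raise ValueError until more backends exist.
    os.environ["VECTOR_DB"] = "qdrant"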

View File

@@ -1,12 +1,25 @@
 from typing import Type

+from danswer.configs.app_configs import DEFAULT_VECTOR_STORE
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.qdrant.store import QdrantDatastore


-def get_selected_datastore_cls() -> Type[Datastore]:
+def get_selected_datastore_cls(
+    vector_db_type: str = DEFAULT_VECTOR_STORE,
+) -> Type[Datastore]:
     """Returns the selected Datastore cls. Only one datastore
     should be selected for a specific deployment."""
-    # TOOD: when more datastores are added, look at env variable to
-    # figure out which one should be returned
-    return QdrantDatastore
+    if vector_db_type == "qdrant":
+        return QdrantDatastore
+    else:
+        raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
+
+
+def create_datastore(
+    collection: str, vector_db_type: str = DEFAULT_VECTOR_STORE
+) -> Datastore:
+    if vector_db_type == "qdrant":
+        return QdrantDatastore(collection=collection)
+    else:
+        raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
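
A usage sketch of the two factory helpers above; the collection name is illustrative:

    from danswer.datastores import create_datastore, get_selected_datastore_cls

    # Resolves to QdrantDatastore under the default VECTOR_DB setting
    datastore_cls = get_selected_datastore_cls()

    # Instantiates the selected backend bound to a single collection
    datastore = create_datastore(collection="danswer_index")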

View File

@@ -3,6 +3,8 @@ import abc
 from danswer.chunking.models import EmbeddedIndexChunk
 from danswer.chunking.models import InferenceChunk

+DatastoreFilter = dict[str, str | list[str] | None]
+

 class Datastore:
     @abc.abstractmethod
@@ -10,5 +12,7 @@ class Datastore:
         raise NotImplementedError

     @abc.abstractmethod
-    def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
+    def semantic_retrieval(
+        self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
+    ) -> list[InferenceChunk]:
         raise NotImplementedError
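
Under the DatastoreFilter alias above, each filter is a dict whose values may be a single string, a list of strings, or None (None entries are skipped by the Qdrant implementation below). An illustrative value, assuming the SOURCE_TYPE constant used later resolves to "source_type":

    filters: list[DatastoreFilter] = [
        {"source_type": ["slack", "web"]},  # match any of several sources
        {"source_type": "github"},          # match one exact value
    ]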

View File

@@ -2,12 +2,17 @@ from danswer.chunking.models import EmbeddedIndexChunk
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
 from danswer.datastores.interfaces import Datastore
+from danswer.datastores.interfaces import DatastoreFilter
 from danswer.datastores.qdrant.indexing import index_chunks
 from danswer.embedding.biencoder import get_default_model
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
+from danswer.utils.timing import build_timing_wrapper
+from qdrant_client.http.exceptions import ResponseHandlingException
+from qdrant_client.http.exceptions import UnexpectedResponse
 from qdrant_client.http.models import FieldCondition
 from qdrant_client.http.models import Filter
+from qdrant_client.http.models import MatchAny
 from qdrant_client.http.models import MatchValue

 logger = setup_logger()
@@ -23,25 +28,63 @@ class QdrantDatastore(Datastore):
             chunks=chunks, collection=self.collection, client=self.client
         )

-    def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
+    @build_timing_wrapper()
+    def semantic_retrieval(
+        self, query: str, filters: list[DatastoreFilter] | None, num_to_retrieve: int
+    ) -> list[InferenceChunk]:
         query_embedding = get_default_model().encode(
             query
         )  # TODO: make this part of the embedder interface
-        hits = self.client.search(
-            collection_name=self.collection,
-            query_vector=query_embedding
-            if isinstance(query_embedding, list)
-            else query_embedding.tolist(),
-            query_filter=None,
-            limit=num_to_retrieve,
-        )
+        if not isinstance(query_embedding, list):
+            query_embedding = query_embedding.tolist()
+
+        hits = []
+        filter_conditions = []
+        try:
+            if filters:
+                for filter_dict in filters:
+                    valid_filters = {
+                        key: value
+                        for key, value in filter_dict.items()
+                        if value is not None
+                    }
+                    for filter_key, filter_val in valid_filters.items():
+                        if isinstance(filter_val, str):
+                            filter_conditions.append(
+                                FieldCondition(
+                                    key=filter_key,
+                                    match=MatchValue(value=filter_val),
+                                )
+                            )
+                        elif isinstance(filter_val, list):
+                            filter_conditions.append(
+                                FieldCondition(
+                                    key=filter_key,
+                                    match=MatchAny(any=filter_val),
+                                )
+                            )
+                        else:
+                            raise ValueError("Invalid filters provided")
+
+            hits = self.client.search(
+                collection_name=self.collection,
+                query_vector=query_embedding,
+                query_filter=Filter(must=list(filter_conditions)),
+                limit=num_to_retrieve,
+            )
+        except ResponseHandlingException as e:
+            logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
+        except UnexpectedResponse as e:
+            logger.exception(
+                f'Qdrant querying failed due to: "{e}", has ingestion been run?'
+            )
+
         return [InferenceChunk.from_dict(hit.payload) for hit in hits]

     def get_from_id(self, object_id: str) -> InferenceChunk | None:
         matches, _ = self.client.scroll(
             collection_name=self.collection,
             scroll_filter=Filter(
-                should=[FieldCondition(key="id", match=MatchValue(value=object_id))]
+                must=[FieldCondition(key="id", match=MatchValue(value=object_id))]
             ),
         )
         if not matches:
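
For concreteness, a sketch of the Qdrant Filter the translation above builds for one sample input; the "source_type" and "author" keys are illustrative:

    from qdrant_client.http.models import FieldCondition, Filter, MatchAny, MatchValue

    # filters = [{"source_type": ["slack", "web"], "author": "alice"}] becomes:
    query_filter = Filter(
        must=[
            # list value -> match any of the listed values
            FieldCondition(key="source_type", match=MatchAny(any=["slack", "web"])),
            # string value -> exact match
            FieldCondition(key="author", match=MatchValue(value="alice")),
        ]
    )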

View File

@@ -1,3 +1,4 @@
+import json
 from typing import List

 import openai
@@ -10,11 +11,10 @@ from danswer.configs.model_configs import CROSS_ENCODER_MODEL
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import MODEL_CACHE_FOLDER
 from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
-from danswer.utils.clients import get_qdrant_client
+from danswer.datastores.interfaces import Datastore
+from danswer.datastores.interfaces import DatastoreFilter
 from danswer.utils.logging import setup_logger
 from danswer.utils.timing import build_timing_wrapper
-from qdrant_client.http.exceptions import ResponseHandlingException
-from qdrant_client.http.exceptions import UnexpectedResponse
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
@@ -32,50 +32,13 @@ cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
 cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE


-@build_timing_wrapper()
-def semantic_retrival(
-    qdrant_collection: str,
-    query: str,
-    num_hits: int = NUM_RETURNED_HITS,
-    use_openai: bool = False,
-) -> List[InferenceChunk]:
-    if use_openai:
-        query_embedding = openai.Embedding.create(
-            input=query, model="text-embedding-ada-002"
-        )["data"][0]["embedding"]
-    else:
-        query_embedding = embedding_model.encode(query)
-
-    try:
-        hits = get_qdrant_client().search(
-            collection_name=qdrant_collection,
-            query_vector=query_embedding
-            if isinstance(query_embedding, list)
-            else query_embedding.tolist(),
-            query_filter=None,
-            limit=num_hits,
-        )
-    except ResponseHandlingException as e:
-        logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
-    except UnexpectedResponse as e:
-        logger.exception(
-            f'Qdrant querying failed due to: "{e}", has ingestion been run?'
-        )
-
-    retrieved_chunks = []
-    for hit in hits:
-        payload = hit.payload
-        retrieved_chunks.append(InferenceChunk.from_dict(payload))
-
-    return retrieved_chunks
-
-
 @build_timing_wrapper()
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
     filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk]:
-    sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])
+    sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
     scored_results = list(zip(sim_scores, chunks))
     scored_results.sort(key=lambda x: x[0], reverse=True)
     ranked_sim_scores, ranked_chunks = zip(*scored_results)
@@ -88,11 +51,18 @@ def semantic_reranking(
 def semantic_search(
-    qdrant_collection: str,
     query: str,
+    filters: list[DatastoreFilter] | None,
+    datastore: Datastore,
     num_hits: int = NUM_RETURNED_HITS,
     filtered_result_set_size: int = NUM_RERANKED_RESULTS,
-) -> List[InferenceChunk]:
-    top_chunks = semantic_retrival(qdrant_collection, query, num_hits)
+) -> List[InferenceChunk] | None:
+    top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
+    if not top_chunks:
+        filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
+        logger.warning(
+            f"Semantic search returned no results with filters: {filters_log_msg}"
+        )
+        return None
     ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
     return ranked_chunks
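
An end-to-end sketch of the reworked entry point, assuming a running Qdrant instance, an ingested collection, and that semantic_search is importable from this module (module path and collection name are illustrative):

    from danswer.datastores import create_datastore

    datastore = create_datastore(collection="danswer_index")
    ranked = semantic_search(
        query="How do I configure the vector DB?",
        filters=None,  # or e.g. [{"source_type": ["web"]}]
        datastore=datastore,
    )
    if ranked is None:
        print("no results")  # semantic_search now returns None on empty hits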

View File

@@ -1,13 +1,12 @@
 import time
 from http import HTTPStatus
-from typing import Dict
-from typing import List
-from typing import Union

 from danswer.configs.app_configs import DEFAULT_PROMPT
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
+from danswer.datastores import create_datastore
+from danswer.datastores.interfaces import DatastoreFilter
 from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
 from danswer.direct_qa.question_answer import answer_question
 from danswer.direct_qa.question_answer import process_answer
@@ -30,15 +29,16 @@ class ServerStatus(BaseModel):
 class QAQuestion(BaseModel):
     query: str
     collection: str
+    filters: list[DatastoreFilter] | None


 class QAResponse(BaseModel):
-    answer: Union[str, None]
-    quotes: Union[Dict[str, Dict[str, str]], None]
+    answer: str | None
+    quotes: dict[str, dict[str, str]] | None


 class KeywordResponse(BaseModel):
-    results: Union[List[str], None]
+    results: list[str] | None


 @router.get("/", response_model=ServerStatus)
@@ -52,16 +52,24 @@ def direct_qa(question: QAQuestion):
     prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
     query = question.query
     collection = question.collection
+    filters = question.filters
+    datastore = create_datastore(collection)

     logger.info(f"Received semantic query: {query}")

-    start_time = time.time()
-    ranked_chunks = semantic_search(collection, query)
+    start_time = time.time()
+    ranked_chunks = semantic_search(query, filters, datastore)
     sem_search_time = time.time()

-    logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
+    if not ranked_chunks:
+        return {"answer": None, "quotes": None}
+
     top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
     top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
+
+    logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
     logger.info(files_log_msg)
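
A sketch of a request against the updated endpoint, mirroring the QAQuestion model above; the host and port follow the APP_PORT = 8080 default and the collection name is illustrative:

    import requests

    response = requests.post(
        "http://127.0.0.1:8080/direct-qa",
        json={
            "query": "What is Danswer?",
            "collection": "danswer_index",
            # "source_type" assumes the value of the SOURCE_TYPE constant
            "filters": [{"source_type": ["web"]}],
        },
    )
    print(response.json())  # {"answer": ..., "quotes": ...}, or nulls if nothing matched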

View File

@@ -1,3 +1,4 @@
+import argparse
 import json

 import requests
@@ -8,35 +9,60 @@ from danswer.configs.constants import SOURCE_TYPE

 if __name__ == "__main__":
-    previous_query = None
+    previous_input = None
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-k",
+        "--keyword-search",
+        action="store_true",
+        help="Use keyword search if set, semantic search otherwise",
+    )
+    parser.add_argument(
+        "-s",
+        "--source-types",
+        type=str,
+        help="Comma separated list of source types to filter by",
+    )
+    parser.add_argument("query", nargs="*", help="The query to process")
+
     while True:
         try:
-            keyword_search = False
-            query = input(
-                "\n\nAsk any question:\n - prefix with -k for keyword search\n - input an empty string to "
-                "rerun last query\n\t"
+            user_input = input(
+                "\n\nAsk any question:\n"
+                " - prefix with -s to add a filter by source(s)\n"
+                " - input an empty string to rerun last query\n\t"
             )
-            if query.lower() in ["q", "quit", "exit", "exit()"]:
-                break
-            if query:
-                previous_query = query
+
+            if user_input:
+                previous_input = user_input
             else:
-                if not previous_query:
-                    print("No previous query")
+                if not previous_input:
+                    print("No previous input")
                     continue
-                print(f"Re-executing previous question:\n\t{previous_query}")
-                query = previous_query
+                print(f"Re-executing previous question:\n\t{previous_input}")
+                user_input = previous_input
+
+            args = parser.parse_args(user_input.split())
+            keyword_search = args.keyword_search
+            source_types = args.source_types.split(",") if args.source_types else None
+            if source_types and len(source_types) == 1:
+                source_types = source_types[0]
+            query = " ".join(args.query)

             endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
-            if query.startswith("-k "):
-                keyword_search = True
-                query = query[2:]
+            if args.keyword_search:
                 endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"

-            response = requests.post(
-                endpoint, json={"query": query, "collection": QDRANT_DEFAULT_COLLECTION}
-            )
+            query_json = {
+                "query": query,
+                "collection": QDRANT_DEFAULT_COLLECTION,
+                "filters": [{SOURCE_TYPE: source_types}],
+            }
+            response = requests.post(endpoint, json=query_json)
             contents = json.loads(response.content)
             if keyword_search:
                 if contents["results"]:
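
With the argparse changes above, a prompt entry like

    -s slack,web how do I reset my password

runs a semantic query filtered to the slack and web sources (a single source collapses to a plain string, matching the MatchValue branch in the Qdrant store), while prefixing -k routes the same words to the keyword-search endpoint instead.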