DAN-71 Give back all ranked results from search (#34)

This commit is contained in:
Yuhong Sun
2023-05-12 15:47:58 -07:00
committed by GitHub
parent 66130c8845
commit ae2a1d3121
5 changed files with 51 additions and 38 deletions

View File

@@ -260,7 +260,7 @@ class OpenAICompletionQA(QAModel):
if stream_answer_end(model_previous, event_text): if stream_answer_end(model_previous, event_text):
found_answer_end = True found_answer_end = True
continue continue
yield {"answer data": event_text} yield {"answer_data": event_text}
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)

View File

@@ -19,7 +19,6 @@ import json
from typing import List from typing import List
from danswer.chunking.models import InferenceChunk from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL from danswer.configs.model_configs import CROSS_ENCODER_MODEL
@@ -67,7 +66,6 @@ def warm_up_models() -> None:
def semantic_reranking( def semantic_reranking(
query: str, query: str,
chunks: List[InferenceChunk], chunks: List[InferenceChunk],
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
) -> List[InferenceChunk]: ) -> List[InferenceChunk]:
cross_encoder = get_default_reranking_model() cross_encoder = get_default_reranking_model()
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
@@ -75,11 +73,9 @@ def semantic_reranking(
scored_results.sort(key=lambda x: x[0], reverse=True) scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_chunks = zip(*scored_results) ranked_sim_scores, ranked_chunks = zip(*scored_results)
logger.debug( logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")
f"Reranked similarity scores: {str(ranked_sim_scores[:filtered_result_set_size])}"
)
return ranked_chunks[:filtered_result_set_size] return ranked_chunks
@log_function_time() @log_function_time()
@@ -88,7 +84,6 @@ def retrieve_ranked_documents(
filters: list[DatastoreFilter] | None, filters: list[DatastoreFilter] | None,
datastore: Datastore, datastore: Datastore,
num_hits: int = NUM_RETURNED_HITS, num_hits: int = NUM_RETURNED_HITS,
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
) -> List[InferenceChunk] | None: ) -> List[InferenceChunk] | None:
top_chunks = datastore.semantic_retrieval(query, filters, num_hits) top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
if not top_chunks: if not top_chunks:
@@ -97,7 +92,7 @@ def retrieve_ranked_documents(
f"Semantic search returned no results with filters: {filters_log_msg}" f"Semantic search returned no results with filters: {filters_log_msg}"
) )
return None return None
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size) ranked_chunks = semantic_reranking(query, top_chunks)
top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks] top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}" files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"

View File

@@ -10,6 +10,13 @@ class UserRoleResponse(BaseModel):
role: str role: str
class SearchDoc(BaseModel):
semantic_name: str
link: str
blurb: str
source_type: str
class QAQuestion(BaseModel): class QAQuestion(BaseModel):
query: str query: str
collection: str collection: str
@@ -19,6 +26,7 @@ class QAQuestion(BaseModel):
class QAResponse(BaseModel): class QAResponse(BaseModel):
answer: str | None answer: str | None
quotes: dict[str, dict[str, str | int | None]] | None quotes: dict[str, dict[str, str | int | None]] | None
ranked_documents: list[SearchDoc] | None
class KeywordResponse(BaseModel): class KeywordResponse(BaseModel):

View File

@@ -5,6 +5,7 @@ from danswer.auth.schemas import UserRole
from danswer.auth.users import current_active_user from danswer.auth.users import current_active_user
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import KEYWORD_MAX_HITS from danswer.configs.app_configs import KEYWORD_MAX_HITS
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.constants import CONTENT from danswer.configs.constants import CONTENT
from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_LINKS
from danswer.datastores import create_datastore from danswer.datastores import create_datastore
@@ -16,6 +17,7 @@ from danswer.semantic_search.semantic_search import retrieve_ranked_documents
from danswer.server.models import KeywordResponse from danswer.server.models import KeywordResponse
from danswer.server.models import QAQuestion from danswer.server.models import QAQuestion
from danswer.server.models import QAResponse from danswer.server.models import QAResponse
from danswer.server.models import SearchDoc
from danswer.server.models import ServerStatus from danswer.server.models import ServerStatus
from danswer.server.models import UserByEmail from danswer.server.models import UserByEmail
from danswer.server.models import UserRoleResponse from danswer.server.models import UserRoleResponse
@@ -81,7 +83,7 @@ async def promote_admin(
@router.get("/direct-qa", response_model=QAResponse) @router.get("/direct-qa", response_model=QAResponse)
def direct_qa(question: QAQuestion = Depends()): def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
start_time = time.time() start_time = time.time()
query = question.query query = question.query
@@ -93,21 +95,31 @@ def direct_qa(question: QAQuestion = Depends()):
query, filters, create_datastore(collection) query, filters, create_datastore(collection)
) )
if not ranked_chunks: if not ranked_chunks:
return {"answer": None, "quotes": None} return QAResponse(answer=None, quotes=None, ranked_documents=None)
top_docs = [
SearchDoc(
semantic_name=chunk.semantic_identifier,
link=chunk.source_links.get("0") if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
)
for chunk in ranked_chunks
]
qa_model = get_default_backend_qa_model() qa_model = get_default_backend_qa_model()
answer, quotes = qa_model.answer_question(query, ranked_chunks) answer, quotes = qa_model.answer_question(
query, ranked_chunks[:NUM_RERANKED_RESULTS]
)
logger.info(f"Total QA took {time.time() - start_time} seconds") logger.info(f"Total QA took {time.time() - start_time} seconds")
return QAResponse(answer=answer, quotes=quotes) return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)
@router.get("/stream-direct-qa") @router.get("/stream-direct-qa")
def stream_direct_qa(question: QAQuestion = Depends()): def stream_direct_qa(question: QAQuestion = Depends()):
top_documents_key = "top_documents" top_documents_key = "top_documents"
answer_key = "answer"
quotes_key = "quotes"
def stream_qa_portions(): def stream_qa_portions():
query = question.query query = question.query
@@ -120,26 +132,25 @@ def stream_direct_qa(question: QAQuestion = Depends()):
) )
if not ranked_chunks: if not ranked_chunks:
return yield_json_line( return yield_json_line(
{top_documents_key: None, answer_key: None, quotes_key: None} QAResponse(answer=None, quotes=None, ranked_documents=None)
) )
linked_chunks = [ top_docs = [
chunk SearchDoc(
semantic_name=chunk.semantic_identifier,
link=chunk.source_links.get("0") if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
)
for chunk in ranked_chunks for chunk in ranked_chunks
if chunk.source_links and "0" in chunk.source_links
] ]
top_docs = { top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
top_documents_key: { yield yield_json_line(top_docs_dict)
"document section links": [
chunk.source_links["0"] for chunk in linked_chunks
],
"blurbs": [chunk.blurb for chunk in linked_chunks],
}
}
yield yield_json_line(top_docs)
qa_model = get_default_backend_qa_model() qa_model = get_default_backend_qa_model()
for response_dict in qa_model.answer_question_stream(query, ranked_chunks): for response_dict in qa_model.answer_question_stream(
query, ranked_chunks[:NUM_RERANKED_RESULTS]
):
logger.debug(response_dict) logger.debug(response_dict)
yield yield_json_line(response_dict) yield yield_json_line(response_dict)

View File

@@ -1,5 +1,6 @@
import argparse import argparse
import json import json
import urllib
import requests import requests
from danswer.configs.app_configs import APP_PORT from danswer.configs.app_configs import APP_PORT
@@ -78,7 +79,9 @@ if __name__ == "__main__":
"filters": [{SOURCE_TYPE: source_types}], "filters": [{SOURCE_TYPE: source_types}],
} }
if not args.stream: if not args.stream:
response = requests.get(endpoint, json=query_json) response = requests.get(
endpoint, params=urllib.parse.urlencode(query_json)
)
contents = json.loads(response.content) contents = json.loads(response.content)
if keyword_search: if keyword_search:
if contents["results"]: if contents["results"]:
@@ -106,15 +109,11 @@ else:
else: else:
print("No quotes found") print("No quotes found")
else: else:
answer = "" with requests.get(
with requests.get(endpoint, json=query_json, stream=True) as r: endpoint, params=urllib.parse.urlencode(query_json), stream=True
) as r:
for json_response in r.iter_lines(): for json_response in r.iter_lines():
response_dict = json.loads(json_response.decode()) print(json.loads(json_response.decode()))
if "answer data" not in response_dict:
print(response_dict)
else:
answer += response_dict["answer data"]
print(answer)
except Exception as e: except Exception as e:
print(f"Failed due to {e}, retrying") print(f"Failed due to {e}, retrying")