DAN-71 Give back all ranked results from search (#34)

Yuhong Sun 2023-05-12 15:47:58 -07:00 committed by GitHub
parent 66130c8845
commit ae2a1d3121
5 changed files with 51 additions and 38 deletions

View File

@@ -260,7 +260,7 @@ class OpenAICompletionQA(QAModel):
                 if stream_answer_end(model_previous, event_text):
                     found_answer_end = True
                     continue
-                yield {"answer data": event_text}
+                yield {"answer_data": event_text}
         except Exception as e:
             logger.exception(e)

View File

@@ -19,7 +19,6 @@ import json
 from typing import List

 from danswer.chunking.models import InferenceChunk
-from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL
@@ -67,7 +66,6 @@ def warm_up_models() -> None:
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk]:
     cross_encoder = get_default_reranking_model()
     sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
@@ -75,11 +73,9 @@ def semantic_reranking(
     scored_results.sort(key=lambda x: x[0], reverse=True)
     ranked_sim_scores, ranked_chunks = zip(*scored_results)

-    logger.debug(
-        f"Reranked similarity scores: {str(ranked_sim_scores[:filtered_result_set_size])}"
-    )
+    logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")

-    return ranked_chunks[:filtered_result_set_size]
+    return ranked_chunks


 @log_function_time()
@@ -88,7 +84,6 @@ def retrieve_ranked_documents(
     filters: list[DatastoreFilter] | None,
     datastore: Datastore,
     num_hits: int = NUM_RETURNED_HITS,
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk] | None:
     top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
     if not top_chunks:
@@ -97,7 +92,7 @@ def retrieve_ranked_documents(
             f"Semantic search returned no results with filters: {filters_log_msg}"
         )
         return None
-    ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
+    ranked_chunks = semantic_reranking(query, top_chunks)
     top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
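With the truncation parameter gone, semantic_reranking and retrieve_ranked_documents now return every ranked chunk, and any consumer that still wants only the top few applies the cut itself (the server change further down does exactly this). A minimal sketch of caller-side truncation, reusing names from this diff; the query string and collection name are placeholders:

# Sketch: retrieve all ranked chunks, then truncate at the call site.
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.datastores import create_datastore
from danswer.semantic_search.semantic_search import retrieve_ranked_documents

ranked_chunks = retrieve_ranked_documents(
    "how does pto work?", filters=None, datastore=create_datastore("semantic_search")
)
if ranked_chunks is not None:
    qa_chunks = ranked_chunks[:NUM_RERANKED_RESULTS]  # the cut the library used to make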

View File

@@ -10,6 +10,13 @@ class UserRoleResponse(BaseModel):
     role: str


+class SearchDoc(BaseModel):
+    semantic_name: str
+    link: str | None
+    blurb: str
+    source_type: str
+
+
 class QAQuestion(BaseModel):
     query: str
     collection: str
@@ -19,6 +26,7 @@ class QAQuestion(BaseModel):
 class QAResponse(BaseModel):
     answer: str | None
     quotes: dict[str, dict[str, str | int | None]] | None
+    ranked_documents: list[SearchDoc] | None


 class KeywordResponse(BaseModel):
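With SearchDoc and the new ranked_documents field, a /direct-qa response now carries every ranked hit alongside the answer and quotes. A rough sketch of the JSON these models serialize to; every value below is invented for illustration:

# Hypothetical /direct-qa response body (all values invented):
{
    "answer": "Yes, unused PTO rolls over at year end.",
    "quotes": {"unused PTO rolls over": {"document_id": "abc123", "link": None}},
    "ranked_documents": [
        {
            "semantic_name": "PTO Policy",
            "link": "https://example.com/wiki/pto",
            "blurb": "Unused PTO rolls over into the next calendar year...",
            "source_type": "web",
        }
    ],
}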

View File

@@ -5,6 +5,7 @@ from danswer.auth.schemas import UserRole
 from danswer.auth.users import current_active_user
 from danswer.auth.users import current_admin_user
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
+from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
@@ -16,6 +17,7 @@ from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
 from danswer.server.models import QAResponse
+from danswer.server.models import SearchDoc
 from danswer.server.models import ServerStatus
 from danswer.server.models import UserByEmail
 from danswer.server.models import UserRoleResponse
@@ -81,7 +83,7 @@ async def promote_admin(
 @router.get("/direct-qa", response_model=QAResponse)
-def direct_qa(question: QAQuestion = Depends()):
+def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
     start_time = time.time()

     query = question.query
@@ -93,21 +95,31 @@ def direct_qa(question: QAQuestion = Depends()):
         query, filters, create_datastore(collection)
     )
     if not ranked_chunks:
-        return {"answer": None, "quotes": None}
+        return QAResponse(answer=None, quotes=None, ranked_documents=None)
+
+    top_docs = [
+        SearchDoc(
+            semantic_name=chunk.semantic_identifier,
+            link=chunk.source_links.get("0") if chunk.source_links else None,
+            blurb=chunk.blurb,
+            source_type=chunk.source_type,
+        )
+        for chunk in ranked_chunks
+    ]

     qa_model = get_default_backend_qa_model()
-    answer, quotes = qa_model.answer_question(query, ranked_chunks)
+    answer, quotes = qa_model.answer_question(
+        query, ranked_chunks[:NUM_RERANKED_RESULTS]
+    )

     logger.info(f"Total QA took {time.time() - start_time} seconds")
-    return QAResponse(answer=answer, quotes=quotes)
+    return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)


 @router.get("/stream-direct-qa")
 def stream_direct_qa(question: QAQuestion = Depends()):
     top_documents_key = "top_documents"
     answer_key = "answer"
     quotes_key = "quotes"

     def stream_qa_portions():
         query = question.query
@@ -120,26 +132,25 @@ def stream_direct_qa(question: QAQuestion = Depends()):
             )
             if not ranked_chunks:
                 return yield_json_line(
-                    {top_documents_key: None, answer_key: None, quotes_key: None}
+                    QAResponse(answer=None, quotes=None, ranked_documents=None)
                 )

-            linked_chunks = [
-                chunk
-                for chunk in ranked_chunks
-                if chunk.source_links and "0" in chunk.source_links
-            ]
-            top_docs = {
-                top_documents_key: {
-                    "document section links": [
-                        chunk.source_links["0"] for chunk in linked_chunks
-                    ],
-                    "blurbs": [chunk.blurb for chunk in linked_chunks],
-                }
-            }
-            yield yield_json_line(top_docs)
+            top_docs = [
+                SearchDoc(
+                    semantic_name=chunk.semantic_identifier,
+                    link=chunk.source_links.get("0") if chunk.source_links else None,
+                    blurb=chunk.blurb,
+                    source_type=chunk.source_type,
+                )
+                for chunk in ranked_chunks
+            ]
+            top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
+            yield yield_json_line(top_docs_dict)

             qa_model = get_default_backend_qa_model()
-            for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
+            for response_dict in qa_model.answer_question_stream(
+                query, ranked_chunks[:NUM_RERANKED_RESULTS]
+            ):
                 logger.debug(response_dict)
                 yield yield_json_line(response_dict)
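The reworked stream emits one JSON line up front with the full ranked document list, then incremental packets from the QA model. Note that each entry in top_documents is itself a JSON string, since the handler serializes with top_doc.json(). A minimal consumer sketch; the host, port, and query values are assumptions, not part of this diff:

import json

import requests

with requests.get(
    "http://localhost:8080/stream-direct-qa",  # host and port assumed
    params={"query": "How does PTO work?", "collection": "semantic_search"},
    stream=True,
) as r:
    answer = ""
    for line in r.iter_lines():
        if not line:
            continue  # iter_lines can yield keep-alive blanks
        packet = json.loads(line.decode())
        if "top_documents" in packet:
            # Each element is a JSON-encoded SearchDoc, so decode it again.
            docs = [json.loads(doc) for doc in packet["top_documents"]]
            print("ranked docs:", [doc["semantic_name"] for doc in docs])
        elif "answer_data" in packet:
            answer += packet["answer_data"]  # streamed answer tokens
        else:
            print(packet)  # e.g. the final quotes payload
print(answer)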

View File

@@ -1,5 +1,6 @@
 import argparse
 import json
+import urllib.parse

 import requests
 from danswer.configs.app_configs import APP_PORT
@@ -78,7 +79,9 @@ if __name__ == "__main__":
             "filters": [{SOURCE_TYPE: source_types}],
         }
         if not args.stream:
-            response = requests.get(endpoint, json=query_json)
+            response = requests.get(
+                endpoint, params=urllib.parse.urlencode(query_json)
+            )
             contents = json.loads(response.content)
             if keyword_search:
                 if contents["results"]:
@@ -106,15 +109,11 @@ if __name__ == "__main__":
             else:
                 print("No quotes found")
         else:
-            answer = ""
-            with requests.get(endpoint, json=query_json, stream=True) as r:
+            with requests.get(
+                endpoint, params=urllib.parse.urlencode(query_json), stream=True
+            ) as r:
                 for json_response in r.iter_lines():
-                    response_dict = json.loads(json_response.decode())
-                    if "answer data" not in response_dict:
-                        print(response_dict)
-                    else:
-                        answer += response_dict["answer data"]
-            print(answer)
+                    print(json.loads(json_response.decode()))
     except Exception as e:
         print(f"Failed due to {e}, retrying")