Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-01 00:18:18 +02:00)
DAN-71 Give back all ranked results from search (#34)
parent 66130c8845
commit ae2a1d3121
@@ -260,7 +260,7 @@ class OpenAICompletionQA(QAModel):
                 if stream_answer_end(model_previous, event_text):
                     found_answer_end = True
                     continue
-                yield {"answer data": event_text}
+                yield {"answer_data": event_text}
 
         except Exception as e:
             logger.exception(e)
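A note on the hunk above: the stream's token key changed from "answer data" to "answer_data", so anything consuming this generator must switch its lookups accordingly. A minimal, self-contained consumer sketch (the helper below is illustrative, not part of the codebase):

from typing import Iterator


def consume_answer_stream(packets: Iterator[dict]) -> str:
    """Accumulate streamed tokens; the key is now the snake_case "answer_data"."""
    answer = ""
    for packet in packets:
        token = packet.get("answer_data")
        if token is not None:
            answer += token
    return answer


# Two packets shaped like the renamed generator output above.
print(consume_answer_stream(iter([{"answer_data": "Hello"}, {"answer_data": "!"}])))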
@@ -19,7 +19,6 @@ import json
 from typing import List
 
 from danswer.chunking.models import InferenceChunk
-from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL
@@ -67,7 +66,6 @@ def warm_up_models() -> None:
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk]:
     cross_encoder = get_default_reranking_model()
     sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
@@ -75,11 +73,9 @@ def semantic_reranking(
     scored_results.sort(key=lambda x: x[0], reverse=True)
     ranked_sim_scores, ranked_chunks = zip(*scored_results)
 
-    logger.debug(
-        f"Reranked similarity scores: {str(ranked_sim_scores[:filtered_result_set_size])}"
-    )
+    logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")
 
-    return ranked_chunks[:filtered_result_set_size]
+    return ranked_chunks
 
 
 @log_function_time()
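For readers unfamiliar with the reranking step this hunk touches: a cross-encoder scores each (query, chunk) pair jointly, the pairs are sorted by descending score, and after this change the full ranking is returned rather than the top filtered_result_set_size. A standalone sketch of the same pattern using sentence-transformers (the model name and sample data are illustrative assumptions; the repo selects its model via CROSS_ENCODER_MODEL):

from sentence_transformers import CrossEncoder

# Illustrative model; the real choice comes from the CROSS_ENCODER_MODEL config.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "how do I reset my password"
chunks = ["Password resets live under Settings.", "Unrelated release notes."]

# Jointly score each (query, chunk) pair, then sort by descending similarity.
sim_scores = cross_encoder.predict([(query, chunk) for chunk in chunks])
scored = sorted(zip(sim_scores, chunks), key=lambda pair: pair[0], reverse=True)
ranked_chunks = [chunk for _, chunk in scored]
print(ranked_chunks)  # every chunk, best match first -- no truncation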
@@ -88,7 +84,6 @@ def retrieve_ranked_documents(
     filters: list[DatastoreFilter] | None,
     datastore: Datastore,
     num_hits: int = NUM_RETURNED_HITS,
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk] | None:
     top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
     if not top_chunks:
@@ -97,7 +92,7 @@ def retrieve_ranked_documents(
             f"Semantic search returned no results with filters: {filters_log_msg}"
         )
         return None
-    ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
+    ranked_chunks = semantic_reranking(query, top_chunks)
 
     top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
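The net behavioral change in this file: retrieve_ranked_documents now hands back every reranked chunk (up to num_hits), and truncation becomes the caller's job. A self-contained sketch of the new calling pattern, with a stub standing in for the real function and config value:

from typing import List, Optional

NUM_RERANKED_RESULTS = 4  # illustrative stand-in for the config constant


def retrieve_ranked_documents_stub(query: str) -> Optional[List[str]]:
    # Stub for the real function: returns the full reranked list.
    return ["chunk-a", "chunk-b", "chunk-c", "chunk-d", "chunk-e"]


ranked_chunks = retrieve_ranked_documents_stub("reset password")
if ranked_chunks is not None:
    chunks_for_llm = ranked_chunks[:NUM_RERANKED_RESULTS]  # head feeds the QA model
    docs_for_display = ranked_chunks  # the UI can show the whole ranking
    print(len(chunks_for_llm), len(docs_for_display))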
@@ -10,6 +10,13 @@ class UserRoleResponse(BaseModel):
     role: str
 
 
+class SearchDoc(BaseModel):
+    semantic_name: str
+    link: str
+    blurb: str
+    source_type: str
+
+
 class QAQuestion(BaseModel):
     query: str
     collection: str
@@ -19,6 +26,7 @@ class QAQuestion(BaseModel):
 class QAResponse(BaseModel):
     answer: str | None
     quotes: dict[str, dict[str, str | int | None]] | None
+    ranked_documents: list[SearchDoc] | None
 
 
 class KeywordResponse(BaseModel):
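Two observations on the new models. First, SearchDoc.link is annotated str, but both construction sites in this commit pass chunk.source_links.get("0") if chunk.source_links else None, which can be None; pydantic would reject a None there, so link: str | None may be the safer annotation. Second, serialization matches the .json() call used by the streaming endpoint later in this commit. An illustrative round-trip:

from pydantic import BaseModel


class SearchDoc(BaseModel):
    semantic_name: str
    link: str
    blurb: str
    source_type: str


doc = SearchDoc(
    semantic_name="Password Reset Guide",  # illustrative values throughout
    link="https://docs.example.com/reset",
    blurb="How to reset a forgotten password...",
    source_type="confluence",
)
print(doc.json())  # {"semantic_name": "Password Reset Guide", "link": ...}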
@@ -5,6 +5,7 @@ from danswer.auth.schemas import UserRole
 from danswer.auth.users import current_active_user
 from danswer.auth.users import current_admin_user
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
+from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
@@ -16,6 +17,7 @@ from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
 from danswer.server.models import QAResponse
+from danswer.server.models import SearchDoc
 from danswer.server.models import ServerStatus
 from danswer.server.models import UserByEmail
 from danswer.server.models import UserRoleResponse
@@ -81,7 +83,7 @@ async def promote_admin(
 
 
 @router.get("/direct-qa", response_model=QAResponse)
-def direct_qa(question: QAQuestion = Depends()):
+def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
     start_time = time.time()
 
     query = question.query
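A side note on the annotation added above: with both response_model=QAResponse and the -> QAResponse return type, FastAPI's serialization and static type checking now describe the same shape. A minimal, self-contained illustration of the pattern (simplified stand-ins, not the repo's actual module):

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class QAResponse(BaseModel):
    answer: str | None


@app.get("/direct-qa", response_model=QAResponse)
def direct_qa() -> QAResponse:
    # response_model governs serialization of the response body;
    # the return annotation adds mypy-visible typing on top.
    return QAResponse(answer="hello")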
@@ -93,21 +95,31 @@ def direct_qa(question: QAQuestion = Depends()):
         query, filters, create_datastore(collection)
     )
     if not ranked_chunks:
-        return {"answer": None, "quotes": None}
+        return QAResponse(answer=None, quotes=None, ranked_documents=None)
+
+    top_docs = [
+        SearchDoc(
+            semantic_name=chunk.semantic_identifier,
+            link=chunk.source_links.get("0") if chunk.source_links else None,
+            blurb=chunk.blurb,
+            source_type=chunk.source_type,
+        )
+        for chunk in ranked_chunks
+    ]
 
     qa_model = get_default_backend_qa_model()
-    answer, quotes = qa_model.answer_question(query, ranked_chunks)
+    answer, quotes = qa_model.answer_question(
+        query, ranked_chunks[:NUM_RERANKED_RESULTS]
+    )
 
     logger.info(f"Total QA took {time.time() - start_time} seconds")
 
-    return QAResponse(answer=answer, quotes=quotes)
+    return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)
 
 
 @router.get("/stream-direct-qa")
 def stream_direct_qa(question: QAQuestion = Depends()):
     top_documents_key = "top_documents"
     answer_key = "answer"
     quotes_key = "quotes"
 
     def stream_qa_portions():
         query = question.query
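Client-side, the non-streaming endpoint now carries the full ranked document list alongside the answer and quotes. A hedged request sketch (host, port, and collection name are illustrative; the query parameters mirror QAQuestion):

import requests

resp = requests.get(
    "http://127.0.0.1:8080/direct-qa",  # illustrative local server
    params={"query": "how do I reset my password", "collection": "danswer_index"},
)
body = resp.json()
print(body["answer"])
for doc in body["ranked_documents"] or []:
    print(doc["semantic_name"], doc["link"])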
@@ -120,26 +132,25 @@ def stream_direct_qa(question: QAQuestion = Depends()):
         )
         if not ranked_chunks:
             return yield_json_line(
-                {top_documents_key: None, answer_key: None, quotes_key: None}
+                QAResponse(answer=None, quotes=None, ranked_documents=None)
             )
 
-        linked_chunks = [
-            chunk
+        top_docs = [
+            SearchDoc(
+                semantic_name=chunk.semantic_identifier,
+                link=chunk.source_links.get("0") if chunk.source_links else None,
+                blurb=chunk.blurb,
+                source_type=chunk.source_type,
+            )
             for chunk in ranked_chunks
-            if chunk.source_links and "0" in chunk.source_links
         ]
-        top_docs = {
-            top_documents_key: {
-                "document section links": [
-                    chunk.source_links["0"] for chunk in linked_chunks
-                ],
-                "blurbs": [chunk.blurb for chunk in linked_chunks],
-            }
-        }
-        yield yield_json_line(top_docs)
+        top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
+        yield yield_json_line(top_docs_dict)
 
         qa_model = get_default_backend_qa_model()
-        for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
+        for response_dict in qa_model.answer_question_stream(
+            query, ranked_chunks[:NUM_RERANKED_RESULTS]
+        ):
             logger.debug(response_dict)
             yield yield_json_line(response_dict)
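The streaming endpoint now emits newline-delimited JSON: the first packet carries the ranked documents under "top_documents" (note each entry is itself a JSON string, since top_doc.json() is called per document), followed by the model's incremental packets. A consumer sketch (URL and parameters are illustrative):

import json

import requests

with requests.get(
    "http://127.0.0.1:8080/stream-direct-qa",  # illustrative local server
    params={"query": "how do I reset my password", "collection": "danswer_index"},
    stream=True,
) as r:
    for line in r.iter_lines():
        packet = json.loads(line.decode())
        if "top_documents" in packet:
            # First line: ranked docs, each serialized as its own JSON string.
            docs = [json.loads(doc) for doc in packet["top_documents"]]
            print([d["semantic_name"] for d in docs])
        else:
            # Later lines: streamed answer/quote packets from the QA model.
            print(packet)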
@@ -1,5 +1,6 @@
 import argparse
 import json
+import urllib
 
 import requests
 from danswer.configs.app_configs import APP_PORT
@@ -78,7 +79,9 @@ if __name__ == "__main__":
             "filters": [{SOURCE_TYPE: source_types}],
         }
         if not args.stream:
-            response = requests.get(endpoint, json=query_json)
+            response = requests.get(
+                endpoint, params=urllib.parse.urlencode(query_json)
+            )
             contents = json.loads(response.content)
             if keyword_search:
                 if contents["results"]:
@@ -106,15 +109,11 @@ if __name__ == "__main__":
             else:
                 print("No quotes found")
         else:
-            answer = ""
-            with requests.get(endpoint, json=query_json, stream=True) as r:
+            with requests.get(
+                endpoint, params=urllib.parse.urlencode(query_json), stream=True
+            ) as r:
                 for json_response in r.iter_lines():
-                    response_dict = json.loads(json_response.decode())
-                    if "answer data" not in response_dict:
-                        print(response_dict)
-                    else:
-                        answer += response_dict["answer data"]
-            print(answer)
+                    print(json.loads(json_response.decode()))
 
     except Exception as e:
         print(f"Failed due to {e}, retrying")
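The script changes track the endpoint change: these are GET routes, so arguments belong in the query string rather than a JSON body. Two small notes: requests would also urlencode a plain dict passed as params=, and the bare import urllib only works here because requests itself imports urllib.parse first, so import urllib.parse is the more robust spelling. A standalone sketch (endpoint and payload are illustrative):

import json
import urllib.parse

import requests

query_json = {"query": "how do I reset my password", "collection": "danswer_index"}
endpoint = "http://127.0.0.1:8080/direct-qa"  # illustrative host and port

# Encode the arguments into the query string; params=query_json would behave
# the same, because requests urlencodes dicts automatically.
response = requests.get(endpoint, params=urllib.parse.urlencode(query_json))
print(json.loads(response.content))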