DAN-71 Give back all ranked results from search (#34)
@@ -260,7 +260,7 @@ class OpenAICompletionQA(QAModel):
                 if stream_answer_end(model_previous, event_text):
                     found_answer_end = True
                     continue
-                yield {"answer data": event_text}
+                yield {"answer_data": event_text}

         except Exception as e:
             logger.exception(e)
@@ -19,7 +19,6 @@ import json
 from typing import List

 from danswer.chunking.models import InferenceChunk
-from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL
@@ -67,7 +66,6 @@ def warm_up_models() -> None:
 def semantic_reranking(
     query: str,
     chunks: List[InferenceChunk],
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk]:
     cross_encoder = get_default_reranking_model()
     sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
@@ -75,11 +73,9 @@ def semantic_reranking(
     scored_results.sort(key=lambda x: x[0], reverse=True)
     ranked_sim_scores, ranked_chunks = zip(*scored_results)

-    logger.debug(
-        f"Reranked similarity scores: {str(ranked_sim_scores[:filtered_result_set_size])}"
-    )
+    logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")

-    return ranked_chunks[:filtered_result_set_size]
+    return ranked_chunks


 @log_function_time()
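Note: this hunk changes the contract of semantic_reranking: it now sorts and returns every ranked chunk instead of truncating to filtered_result_set_size. The slicing reappears at the QA call sites further down as ranked_chunks[:NUM_RERANKED_RESULTS]. A minimal standalone sketch of the new division of labor (the names and config value below are illustrative stand-ins, not the real danswer code):

from typing import List, Tuple

NUM_RERANKED_RESULTS = 4  # stand-in; the real value lives in danswer.configs.app_configs

def rerank(scored_results: List[Tuple[float, str]]) -> List[str]:
    # Post-commit behavior: sort by similarity score and return *all* results.
    scored_results.sort(key=lambda x: x[0], reverse=True)
    return [chunk for _, chunk in scored_results]

ranked = rerank([(0.2, "chunk-b"), (0.9, "chunk-a"), (0.5, "chunk-c")])
print(ranked)                         # full ranked list -> goes back to the client
print(ranked[:NUM_RERANKED_RESULTS])  # truncated view -> fed to the QA model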
@@ -88,7 +84,6 @@ def retrieve_ranked_documents(
     filters: list[DatastoreFilter] | None,
     datastore: Datastore,
     num_hits: int = NUM_RETURNED_HITS,
-    filtered_result_set_size: int = NUM_RERANKED_RESULTS,
 ) -> List[InferenceChunk] | None:
     top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
     if not top_chunks:
@@ -97,7 +92,7 @@ def retrieve_ranked_documents(
             f"Semantic search returned no results with filters: {filters_log_msg}"
         )
         return None
-    ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
+    ranked_chunks = semantic_reranking(query, top_chunks)

     top_docs = [ranked_chunk.source_links["0"] for ranked_chunk in ranked_chunks]
     files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
@@ -10,6 +10,13 @@ class UserRoleResponse(BaseModel):
     role: str


+class SearchDoc(BaseModel):
+    semantic_name: str
+    link: str
+    blurb: str
+    source_type: str
+
+
 class QAQuestion(BaseModel):
     query: str
     collection: str
@@ -19,6 +26,7 @@ class QAQuestion(BaseModel):
 class QAResponse(BaseModel):
     answer: str | None
     quotes: dict[str, dict[str, str | int | None]] | None
+    ranked_documents: list[SearchDoc] | None


 class KeywordResponse(BaseModel):
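Note: SearchDoc and the new ranked_documents field define the extra payload both QA endpoints now return. A minimal sketch of the shape, assuming pydantic v1 (the .json() call in the streaming handler below is the v1 API); the document values are invented for illustration:

from pydantic import BaseModel

class SearchDoc(BaseModel):
    semantic_name: str
    link: str
    blurb: str
    source_type: str

doc = SearchDoc(
    semantic_name="Onboarding Guide",            # hypothetical document
    link="https://example.com/wiki/onboarding",  # hypothetical link
    blurb="First lines of the matching chunk...",
    source_type="web",
)
print(doc.json())
# {"semantic_name": "Onboarding Guide", "link": "https://example.com/wiki/onboarding", ...}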
@@ -5,6 +5,7 @@ from danswer.auth.schemas import UserRole
 from danswer.auth.users import current_active_user
 from danswer.auth.users import current_admin_user
 from danswer.configs.app_configs import KEYWORD_MAX_HITS
+from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.datastores import create_datastore
@@ -16,6 +17,7 @@ from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
 from danswer.server.models import QAResponse
+from danswer.server.models import SearchDoc
 from danswer.server.models import ServerStatus
 from danswer.server.models import UserByEmail
 from danswer.server.models import UserRoleResponse
@@ -81,7 +83,7 @@ async def promote_admin(


 @router.get("/direct-qa", response_model=QAResponse)
-def direct_qa(question: QAQuestion = Depends()):
+def direct_qa(question: QAQuestion = Depends()) -> QAResponse:
     start_time = time.time()

     query = question.query
@@ -93,21 +95,31 @@ def direct_qa(question: QAQuestion = Depends()):
         query, filters, create_datastore(collection)
     )
     if not ranked_chunks:
-        return {"answer": None, "quotes": None}
+        return QAResponse(answer=None, quotes=None, ranked_documents=None)

+    top_docs = [
+        SearchDoc(
+            semantic_name=chunk.semantic_identifier,
+            link=chunk.source_links.get("0") if chunk.source_links else None,
+            blurb=chunk.blurb,
+            source_type=chunk.source_type,
+        )
+        for chunk in ranked_chunks
+    ]
+
     qa_model = get_default_backend_qa_model()
-    answer, quotes = qa_model.answer_question(query, ranked_chunks)
+    answer, quotes = qa_model.answer_question(
+        query, ranked_chunks[:NUM_RERANKED_RESULTS]
+    )

     logger.info(f"Total QA took {time.time() - start_time} seconds")

-    return QAResponse(answer=answer, quotes=quotes)
+    return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)


 @router.get("/stream-direct-qa")
 def stream_direct_qa(question: QAQuestion = Depends()):
     top_documents_key = "top_documents"
-    answer_key = "answer"
-    quotes_key = "quotes"

     def stream_qa_portions():
         query = question.query
@@ -120,26 +132,25 @@ def stream_direct_qa(question: QAQuestion = Depends()):
         )
         if not ranked_chunks:
             return yield_json_line(
-                {top_documents_key: None, answer_key: None, quotes_key: None}
+                QAResponse(answer=None, quotes=None, ranked_documents=None)
             )

-        linked_chunks = [
-            chunk
+        top_docs = [
+            SearchDoc(
+                semantic_name=chunk.semantic_identifier,
+                link=chunk.source_links.get("0") if chunk.source_links else None,
+                blurb=chunk.blurb,
+                source_type=chunk.source_type,
+            )
             for chunk in ranked_chunks
-            if chunk.source_links and "0" in chunk.source_links
         ]
-        top_docs = {
-            top_documents_key: {
-                "document section links": [
-                    chunk.source_links["0"] for chunk in linked_chunks
-                ],
-                "blurbs": [chunk.blurb for chunk in linked_chunks],
-            }
-        }
-        yield yield_json_line(top_docs)
+        top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
+        yield yield_json_line(top_docs_dict)

         qa_model = get_default_backend_qa_model()
-        for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
+        for response_dict in qa_model.answer_question_stream(
+            query, ranked_chunks[:NUM_RERANKED_RESULTS]
+        ):
             logger.debug(response_dict)
             yield yield_json_line(response_dict)
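Note: after this rework the streaming endpoint emits one JSON object per line: first a top_documents record serialized from the SearchDoc list, then the answer_data tokens (see the OpenAICompletionQA hunk at the top), then any trailing quote payload. A hedged client-side sketch; the host, port, and query values are assumptions:

import json
import requests

url = "http://localhost:8080/stream-direct-qa"  # assumed host/port
params = {"query": "How do I reset my password?", "collection": "semantic_search"}
with requests.get(url, params=params, stream=True) as r:
    for raw_line in r.iter_lines():
        if not raw_line:
            continue  # skip keep-alive blanks
        event = json.loads(raw_line.decode())
        if "top_documents" in event:
            print("ranked docs:", event["top_documents"])
        elif "answer_data" in event:
            print(event["answer_data"], end="", flush=True)  # token-by-token answer
        else:
            print(event)  # e.g. the quotes dict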
@@ -1,5 +1,6 @@
 import argparse
 import json
+import urllib

 import requests
 from danswer.configs.app_configs import APP_PORT
@@ -78,7 +79,9 @@ if __name__ == "__main__":
                 "filters": [{SOURCE_TYPE: source_types}],
             }
             if not args.stream:
-                response = requests.get(endpoint, json=query_json)
+                response = requests.get(
+                    endpoint, params=urllib.parse.urlencode(query_json)
+                )
                 contents = json.loads(response.content)
                 if keyword_search:
                     if contents["results"]:
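Note: the switch from json=query_json to params=urllib.parse.urlencode(query_json) matches the server side, where question: QAQuestion = Depends() reads the request from URL query parameters rather than a JSON body. A quick illustration of the encoding (values invented):

import urllib.parse

query_json = {"query": "How do I reset my password?", "collection": "semantic_search"}
print(urllib.parse.urlencode(query_json))
# query=How+do+I+reset+my+password%3F&collection=semantic_search

requests accepts this pre-encoded string directly via params=. One caveat: the hunk above adds only "import urllib", and urllib.parse is then reachable only because importing requests pulls in urllib.parse as a side effect; "import urllib.parse" would be the self-sufficient form.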
@@ -106,15 +109,11 @@ if __name__ == "__main__":
                     else:
                         print("No quotes found")
             else:
-                answer = ""
-                with requests.get(endpoint, json=query_json, stream=True) as r:
+                with requests.get(
+                    endpoint, params=urllib.parse.urlencode(query_json), stream=True
+                ) as r:
                     for json_response in r.iter_lines():
-                        response_dict = json.loads(json_response.decode())
-                        if "answer data" not in response_dict:
-                            print(response_dict)
-                        else:
-                            answer += response_dict["answer data"]
-                        print(answer)
+                        print(json.loads(json_response.decode()))

         except Exception as e:
             print(f"Failed due to {e}, retrying")