diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 620ef95a3..052b818e5 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -148,6 +148,23 @@ def update_query_event_retrieved_documents( db_session.commit() +def update_query_event_llm_answer( + db_session: Session, + llm_answer: str, + query_id: int, + user_id: UUID | None, +) -> None: + query_event = fetch_query_event_by_id(query_id, db_session) + + if user_id != query_event.user_id: + raise ValueError( + "User trying to update llm_answer on a query run by another user." + ) + + query_event.llm_answer = llm_answer + db_session.commit() + + def create_doc_retrieval_feedback( qa_event_id: int, document_id: str, diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index 489cb2eb6..bb04cb659 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -7,7 +7,7 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.constants import IGNORE_FOR_QA -from danswer.db.feedback import create_query_event +from danswer.db.feedback import update_query_event_llm_answer from danswer.db.models import User from danswer.direct_qa.factory import get_default_qa_model from danswer.direct_qa.interfaces import DanswerAnswerPiece @@ -160,6 +160,15 @@ def answer_qa_query( d_answer, quotes = None, None error_msg = f"Error occurred in call to LLM - {e}" # Used in the QAResponse + # update query event created by call to `danswer_search` with the LLM answer + if d_answer and d_answer.answer is not None: + update_query_event_llm_answer( + db_session=db_session, + llm_answer=d_answer.answer, + query_id=query_event_id, + user_id=None if user is None else user.id, + ) + if not real_time_flow and enable_reflexion and d_answer is not None: valid = False if d_answer.answer is not None: @@ -304,13 +313,12 @@ def answer_qa_query_stream( error = StreamingError(error="The LLM failed to produce a useable response") yield get_json_line(error.dict()) - query_event_id = create_query_event( - query=query, - search_type=question.search_type, - llm_answer=answer_so_far, - retrieved_document_ids=[doc.document_id for doc in top_docs], - user_id=None if user is None else user.id, + # update query event created by call to `danswer_search` with the LLM answer + update_query_event_llm_answer( db_session=db_session, + llm_answer=answer_so_far, + query_id=query_event_id, + user_id=None if user is None else user.id, ) yield get_json_line({"query_event_id": query_event_id})