Also includes some bugfixes
This commit is contained in:
Yuhong Sun 2024-06-22 18:52:48 -07:00 committed by GitHub
parent 7253316b9e
commit 123ec4342a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 92 additions and 32 deletions

View File

@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0568ccf46a6b"
down_revision = "e209dc5a8156"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -46,6 +46,7 @@ from danswer.tools.images.prompt import build_image_generation_user_prompt
from danswer.tools.message import build_tool_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@ -101,6 +102,8 @@ class Answer:
force_use_tool: ForceUseTool | None = None,
# if set to True, then never use the LLM's provided tool-calling functionality
skip_explicit_tool_calling: bool = False,
# Returns the full text of the document sections from the search tool
return_contexts: bool = False,
) -> None:
if single_message_history and message_history:
raise ValueError(
@ -133,6 +136,8 @@ class Answer:
AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
] | None = None
self._return_contexts = return_contexts
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
@ -420,6 +425,12 @@ class Answer:
]
elif message.id == FINAL_CONTEXT_DOCUMENTS:
final_context_docs = cast(list[LlmDoc], message.response)
elif (
message.id == SEARCH_DOC_CONTENT_ID
and not self._return_contexts
):
continue
yield message
else:
# assumes all tool responses will come first, then the final answer

View File

@ -15,7 +15,7 @@ from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_time_to_system_prompt
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.message import ToolCallSummary
@ -25,7 +25,7 @@ def default_build_system_message(
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
if prompt_config.datetime_aware:
system_prompt = add_time_to_system_prompt(system_prompt=system_prompt)
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
if not system_prompt:
return None

View File

@ -17,7 +17,7 @@ from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from danswer.prompts.prompt_utils import add_time_to_system_prompt
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.prompts.prompt_utils import build_task_prompt_reminders
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
@ -121,7 +121,7 @@ def build_citations_system_message(
if prompt_config.include_citations:
system_prompt += REQUIRE_CITATION_STATEMENT
if prompt_config.datetime_aware:
system_prompt = add_time_to_system_prompt(system_prompt=system_prompt)
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
return SystemMessage(content=system_prompt)

View File

@ -9,6 +9,7 @@ from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
@ -35,6 +36,10 @@ def _build_weak_llm_quotes_prompt(
task_prompt=prompt.task_prompt,
user_query=question,
)
if prompt.datetime_aware:
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
return HumanMessage(content=prompt_str)
@ -62,6 +67,10 @@ def _build_strong_llm_quotes_prompt(
user_query=question,
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
).strip()
if prompt.datetime_aware:
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
return HumanMessage(content=full_prompt)

View File

@ -46,6 +46,7 @@ from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephr
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@ -196,6 +197,7 @@ def stream_answer_objects(
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
skip_explicit_tool_calling=True,
return_contexts=query_req.return_contexts,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
dropped_inds: list[int] = []
@ -245,6 +247,8 @@ def stream_answer_objects(
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
else:
yield packet

View File

@ -35,15 +35,15 @@ def get_current_llm_day_time(
return f"{formatted_datetime}"
def add_time_to_system_prompt(system_prompt: str) -> str:
if DANSWER_DATETIME_REPLACEMENT in system_prompt:
return system_prompt.replace(
def add_date_time_to_prompt(prompt_str: str) -> str:
if DANSWER_DATETIME_REPLACEMENT in prompt_str:
return prompt_str.replace(
DANSWER_DATETIME_REPLACEMENT,
get_current_llm_day_time(full_sentence=False, include_day_of_week=True),
)
if system_prompt:
return system_prompt + ADDITIONAL_INFO.format(
if prompt_str:
return prompt_str + ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:

View File

@ -7,6 +7,8 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import LlmDoc
from danswer.db.models import Persona
from danswer.db.models import User
@ -30,6 +32,7 @@ from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
FINAL_CONTEXT_DOCUMENTS = "final_context_documents"
@ -221,6 +224,20 @@ class SearchTool(Tool):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=DanswerContexts(
contexts=[
DanswerContext(
content=section.content,
document_id=section.document_id,
semantic_identifier=section.semantic_identifier,
blurb=section.blurb,
)
for section in search_pipeline.reranked_sections
]
),
)
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=search_pipeline.relevant_chunk_indices,

View File

@ -61,13 +61,8 @@ def read_questions(questions_file_path: str) -> list[dict]:
return samples
def main(questions_file: str, output_file: str, limit: int | None = None) -> None:
samples = read_questions(questions_file)
if limit is not None:
samples = samples[:limit]
response_dicts = []
def get_relari_outputs(samples: list[dict]) -> list[dict]:
relari_outputs = []
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for sample in samples:
answer = get_answer_for_question(
@ -75,22 +70,46 @@ def main(questions_file: str, output_file: str, limit: int | None = None) -> Non
)
assert answer.contexts
response_dict = {
"question": sample["question"],
"retrieved_contexts": [
context.content for context in answer.contexts.contexts
],
"ground_truth_contexts": sample["ground_truth_contexts"],
"answer": answer.answer,
"ground_truths": sample["ground_truths"],
}
relari_outputs.append(
{
"label": sample["uid"],
"question": sample["question"],
"answer": answer.answer,
"retrieved_context": [
context.content for context in answer.contexts.contexts
],
}
)
response_dicts.append(response_dict)
return relari_outputs
with open(output_file, "w", encoding="utf-8") as out_file:
for response_dict in response_dicts:
json_line = json.dumps(response_dict)
out_file.write(json_line + "\n")
def write_output_file(relari_outputs: list[dict], output_file: str) -> None:
    """Serialize each output dict as one JSON object per line (JSONL) to output_file."""
    with open(output_file, "w", encoding="utf-8") as out:
        # writelines with a generator: one newline-terminated JSON record per output
        out.writelines(json.dumps(entry) + "\n" for entry in relari_outputs)
def main(questions_file: str, output_file: str, limit: int | None = None) -> None:
    """Run the eval pipeline: read questions, fetch answers, write Relari-format JSONL.

    Args:
        questions_file: Path to the input questions file.
        output_file: Path where the Relari-format results are written, one JSON
            object per line.
        limit: If set, only the first `limit` samples are processed.
    """
    samples = read_questions(questions_file)
    if limit is not None:
        samples = samples[:limit]

    # NOTE: the output schema used to include "retrieved_contexts",
    # "ground_truth_contexts", and "ground_truths" keys; it has since changed
    # to the Relari format produced by get_relari_outputs.
    relari_outputs = get_relari_outputs(samples=samples)
    write_output_file(relari_outputs, output_file)
if __name__ == "__main__":