mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-12 05:49:36 +02:00
parent
7253316b9e
commit
123ec4342a
@ -12,8 +12,8 @@ import sqlalchemy as sa
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0568ccf46a6b"
|
||||
down_revision = "e209dc5a8156"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
@ -46,6 +46,7 @@ from danswer.tools.images.prompt import build_image_generation_user_prompt
|
||||
from danswer.tools.message import build_tool_message
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@ -101,6 +102,8 @@ class Answer:
|
||||
force_use_tool: ForceUseTool | None = None,
|
||||
# if set to True, then never use the LLMs provided tool-calling functonality
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
# Returns the full document sections text from the search tool
|
||||
return_contexts: bool = False,
|
||||
) -> None:
|
||||
if single_message_history and message_history:
|
||||
raise ValueError(
|
||||
@ -133,6 +136,8 @@ class Answer:
|
||||
AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
|
||||
] | None = None
|
||||
|
||||
self._return_contexts = return_contexts
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
@ -420,6 +425,12 @@ class Answer:
|
||||
]
|
||||
elif message.id == FINAL_CONTEXT_DOCUMENTS:
|
||||
final_context_docs = cast(list[LlmDoc], message.response)
|
||||
elif (
|
||||
message.id == SEARCH_DOC_CONTENT_ID
|
||||
and not self._return_contexts
|
||||
):
|
||||
continue
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
|
@ -15,7 +15,7 @@ from danswer.llm.utils import check_message_tokens
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import translate_history_to_basemessages
|
||||
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_time_to_system_prompt
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from danswer.tools.message import ToolCallSummary
|
||||
|
||||
@ -25,7 +25,7 @@ def default_build_system_message(
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_time_to_system_prompt(system_prompt=system_prompt)
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
if not system_prompt:
|
||||
return None
|
||||
|
@ -17,7 +17,7 @@ from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||
from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
|
||||
from danswer.prompts.prompt_utils import add_time_to_system_prompt
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.prompts.prompt_utils import build_task_prompt_reminders
|
||||
from danswer.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
|
||||
@ -121,7 +121,7 @@ def build_citations_system_message(
|
||||
if prompt_config.include_citations:
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_time_to_system_prompt(system_prompt=system_prompt)
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
return SystemMessage(content=system_prompt)
|
||||
|
||||
|
@ -9,6 +9,7 @@ from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
|
||||
@ -35,6 +36,10 @@ def _build_weak_llm_quotes_prompt(
|
||||
task_prompt=prompt.task_prompt,
|
||||
user_query=question,
|
||||
)
|
||||
|
||||
if prompt.datetime_aware:
|
||||
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
|
||||
|
||||
return HumanMessage(content=prompt_str)
|
||||
|
||||
|
||||
@ -62,6 +67,10 @@ def _build_strong_llm_quotes_prompt(
|
||||
user_query=question,
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
|
||||
if prompt.datetime_aware:
|
||||
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
|
||||
|
||||
return HumanMessage(content=full_prompt)
|
||||
|
||||
|
||||
|
@ -46,6 +46,7 @@ from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephr
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
|
||||
from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@ -196,6 +197,7 @@ def stream_answer_objects(
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
skip_explicit_tool_calling=True,
|
||||
return_contexts=query_req.return_contexts,
|
||||
)
|
||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||
dropped_inds: list[int] = []
|
||||
@ -245,6 +247,8 @@ def stream_answer_objects(
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
else:
|
||||
yield packet
|
||||
|
||||
|
@ -35,15 +35,15 @@ def get_current_llm_day_time(
|
||||
return f"{formatted_datetime}"
|
||||
|
||||
|
||||
def add_time_to_system_prompt(system_prompt: str) -> str:
|
||||
if DANSWER_DATETIME_REPLACEMENT in system_prompt:
|
||||
return system_prompt.replace(
|
||||
def add_date_time_to_prompt(prompt_str: str) -> str:
|
||||
if DANSWER_DATETIME_REPLACEMENT in prompt_str:
|
||||
return prompt_str.replace(
|
||||
DANSWER_DATETIME_REPLACEMENT,
|
||||
get_current_llm_day_time(full_sentence=False, include_day_of_week=True),
|
||||
)
|
||||
|
||||
if system_prompt:
|
||||
return system_prompt + ADDITIONAL_INFO.format(
|
||||
if prompt_str:
|
||||
return prompt_str + ADDITIONAL_INFO.format(
|
||||
datetime_info=get_current_llm_day_time()
|
||||
)
|
||||
else:
|
||||
|
@ -7,6 +7,8 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import DanswerContext
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
@ -30,6 +32,7 @@ from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
|
||||
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
|
||||
SEARCH_DOC_CONTENT_ID = "search_doc_content"
|
||||
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
|
||||
FINAL_CONTEXT_DOCUMENTS = "final_context_documents"
|
||||
|
||||
@ -221,6 +224,20 @@ class SearchTool(Tool):
|
||||
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
|
||||
),
|
||||
)
|
||||
yield ToolResponse(
|
||||
id=SEARCH_DOC_CONTENT_ID,
|
||||
response=DanswerContexts(
|
||||
contexts=[
|
||||
DanswerContext(
|
||||
content=section.content,
|
||||
document_id=section.document_id,
|
||||
semantic_identifier=section.semantic_identifier,
|
||||
blurb=section.blurb,
|
||||
)
|
||||
for section in search_pipeline.reranked_sections
|
||||
]
|
||||
),
|
||||
)
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=search_pipeline.relevant_chunk_indices,
|
||||
|
@ -61,13 +61,8 @@ def read_questions(questions_file_path: str) -> list[dict]:
|
||||
return samples
|
||||
|
||||
|
||||
def main(questions_file: str, output_file: str, limit: int | None = None) -> None:
|
||||
samples = read_questions(questions_file)
|
||||
|
||||
if limit is not None:
|
||||
samples = samples[:limit]
|
||||
|
||||
response_dicts = []
|
||||
def get_relari_outputs(samples: list[dict]) -> list[dict]:
|
||||
relari_outputs = []
|
||||
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
|
||||
for sample in samples:
|
||||
answer = get_answer_for_question(
|
||||
@ -75,22 +70,46 @@ def main(questions_file: str, output_file: str, limit: int | None = None) -> Non
|
||||
)
|
||||
assert answer.contexts
|
||||
|
||||
response_dict = {
|
||||
"question": sample["question"],
|
||||
"retrieved_contexts": [
|
||||
context.content for context in answer.contexts.contexts
|
||||
],
|
||||
"ground_truth_contexts": sample["ground_truth_contexts"],
|
||||
"answer": answer.answer,
|
||||
"ground_truths": sample["ground_truths"],
|
||||
}
|
||||
relari_outputs.append(
|
||||
{
|
||||
"label": sample["uid"],
|
||||
"question": sample["question"],
|
||||
"answer": answer.answer,
|
||||
"retrieved_context": [
|
||||
context.content for context in answer.contexts.contexts
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
response_dicts.append(response_dict)
|
||||
return relari_outputs
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as out_file:
|
||||
for response_dict in response_dicts:
|
||||
json_line = json.dumps(response_dict)
|
||||
out_file.write(json_line + "\n")
|
||||
|
||||
def write_output_file(relari_outputs: list[dict], output_file: str) -> None:
|
||||
with open(output_file, "w", encoding="utf-8") as file:
|
||||
for output in relari_outputs:
|
||||
file.write(json.dumps(output) + "\n")
|
||||
|
||||
|
||||
def main(questions_file: str, output_file: str, limit: int | None = None) -> None:
|
||||
samples = read_questions(questions_file)
|
||||
|
||||
if limit is not None:
|
||||
samples = samples[:limit]
|
||||
|
||||
# Use to be in this format but has since changed
|
||||
# response_dict = {
|
||||
# "question": sample["question"],
|
||||
# "retrieved_contexts": [
|
||||
# context.content for context in answer.contexts.contexts
|
||||
# ],
|
||||
# "ground_truth_contexts": sample["ground_truth_contexts"],
|
||||
# "answer": answer.answer,
|
||||
# "ground_truths": sample["ground_truths"],
|
||||
# }
|
||||
|
||||
relari_outputs = get_relari_outputs(samples=samples)
|
||||
|
||||
write_output_file(relari_outputs, output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user