import ast
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from uuid import UUID

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
    EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
    HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DISPATCH_SEP_CHAR
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
    SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool

BaseMessage_Content = str | list[str | dict[str, Any]]

def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
    return re.sub(r"\s+", " ", text.strip())

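# Example: normalize_whitespace("  foo\n\tbar  ") -> "foo bar"
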
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    formatted_doc_list = []

    for doc_nr, doc in enumerate(docs):
        formatted_doc_list.append(f"Document D{doc_nr + 1}:\n{doc.combined_content}")

    return "\n\n".join(formatted_doc_list)

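# Example (illustrative): for two sections whose combined_content is "foo" and
# "bar", format_docs returns "Document D1:\nfoo\n\nDocument D2:\nbar".
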
def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
    formatted_doc_list = []

    for doc in docs:
        formatted_doc_list.append(f"\n...{doc.combined_content}\n")

    return "\n\n".join(formatted_doc_list)

def clean_and_parse_list_string(json_string: str) -> list[dict]:
    # Remove any prefixes/labels before the actual JSON content
    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)

    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Try parsing with json.loads first, fall back to ast.literal_eval
    try:
        return json.loads(cleaned_string)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(cleaned_string)
        except (ValueError, SyntaxError) as e:
            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e

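# Example (illustrative): a typical LLM reply such as
#   'Sure! ```json\n[{"a": 1}]\n```'
# has the leading label and code fences stripped and parses to [{"a": 1}].
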
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Parse the cleaned string into a Python dictionary
    return json.loads(cleaned_string)

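# Example (illustrative): '```json\n{"k": "v"}\n```' parses to {"k": "v"}.
# Unlike clean_and_parse_list_string, there is no ast.literal_eval fallback,
# so malformed JSON raises json.JSONDecodeError.
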
def format_entity_term_extraction(
    entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
    entities = entity_term_extraction_dict.entities
    terms = entity_term_extraction_dict.terms
    relationships = entity_term_extraction_dict.relationships

    entity_strs = ["\nEntities:\n"]
    for entity in entities:
        entity_str = f"{entity.entity_name} ({entity.entity_type})"
        entity_strs.append(entity_str)

    relationship_strs = ["\n\nRelationships:\n"]
    for relationship in relationships:
        relationship_name = relationship.relationship_name
        relationship_type = relationship.relationship_type
        relationship_entities = relationship.relationship_entities
        relationship_str = (
            f"{relationship_name} ({relationship_type}): {relationship_entities}"
        )
        relationship_strs.append(relationship_str)

    term_strs = ["\n\nTerms:\n"]
    for term in terms:
        term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
        term_strs.append(term_str)

    return "\n".join(entity_strs + relationship_strs + term_strs)

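# Example (illustrative): with a single entity "Acme" of type "company", the
# entities portion renders as "\nEntities:\n\nAcme (company)" once joined;
# the Relationships and Terms sections follow the same header-then-items pattern.
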
def _format_time_delta(time: timedelta) -> str:
    # Note: timedelta.seconds is only the seconds component (< 86400);
    # any days component is ignored here.
    seconds_from_start = f"{time.seconds:03d}"
    microseconds_from_start = f"{time.microseconds:06d}"
    return f"{seconds_from_start}.{microseconds_from_start}"

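# Example: _format_time_delta(timedelta(seconds=3, microseconds=42)) -> "003.000042"
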
def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    current_time = datetime.now()

    if graph_start_time is not None:
        graph_time_str = _format_time_delta(current_time - graph_start_time)
    else:
        graph_time_str = "N/A"

    node_time_str = _format_time_delta(current_time - node_start_time)

    return f"{graph_time_str} ({node_time_str} s): {message}"

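# Example (illustrative): a node that started 1.5 s ago in a graph that started
# 12 s ago logs roughly "012.000000 (001.500000 s): <message>".
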
def get_test_config(
    db_session: Session,
    primary_llm: LLM,
    fast_llm: LLM,
    search_request: SearchRequest,
    use_agentic_search: bool = True,
) -> tuple[AgentSearchConfig, SearchTool]:
    persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
    document_pruning_config = DocumentPruningConfig(
        max_chunks=int(
            persona.num_chunks
            if persona.num_chunks is not None
            else MAX_CHUNKS_FED_TO_CHAT
        ),
        max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
    )

    answer_style_config = AnswerStyleConfig(
        citation_config=CitationConfig(
            # The docs retrieved by this flow are already relevance-filtered
            all_docs_useful=True
        ),
        document_pruning_config=document_pruning_config,
        structured_response_format=None,
    )

    search_tool_config = SearchToolConfig(
        answer_style_config=answer_style_config,
        document_pruning_config=document_pruning_config,
        retrieval_options=RetrievalDetails(),  # may want to set dedupe_docs=True
        rerank_settings=None,  # can use this to change the reranking model
        selected_sections=None,
        latest_query_files=None,
        bypass_acl=False,
    )

    prompt_config = PromptConfig.from_model(persona.prompts[0])

    search_tool = SearchTool(
        db_session=db_session,
        user=None,
        persona=persona,
        retrieval_options=search_tool_config.retrieval_options,
        prompt_config=prompt_config,
        llm=primary_llm,
        fast_llm=fast_llm,
        pruning_config=search_tool_config.document_pruning_config,
        answer_style_config=search_tool_config.answer_style_config,
        selected_sections=search_tool_config.selected_sections,
        chunks_above=search_tool_config.chunks_above,
        chunks_below=search_tool_config.chunks_below,
        full_doc=search_tool_config.full_doc,
        evaluation_type=(
            LLMEvaluationType.BASIC
            if persona.llm_relevance_filter
            else LLMEvaluationType.SKIP
        ),
        rerank_settings=search_tool_config.rerank_settings,
        bypass_acl=search_tool_config.bypass_acl,
    )

    config = AgentSearchConfig(
        search_request=search_request,
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
        prompt_builder=AnswerPromptBuilder(
            user_message=HumanMessage(content=search_request.query),
            message_history=[],
            llm_config=primary_llm.config,
            raw_user_query=search_request.query,
            raw_user_uploaded_files=[],
        ),
        # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
        chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"),  # Joachim
        # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"),  # Evan
        message_id=1,
        use_agentic_persistence=True,
        db_session=db_session,
        tools=[search_tool],
        use_agentic_search=use_agentic_search,
    )

    return config, search_tool

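# Usage sketch (illustrative; assumes a live DB session and configured LLMs,
# and that SearchRequest accepts a query kwarg as used above):
#   with get_session_context_manager() as db_session:
#       config, search_tool = get_test_config(
#           db_session, primary_llm, fast_llm, SearchRequest(query="...")
#       )
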
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
    if persona is None:
        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
        persona_base = ""
    else:
        persona_base = "\n".join([x.system_prompt for x in persona.prompts])

        persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
            persona_prompt=persona_base
        )
    return PersonaExpressions(
        contextualized_prompt=persona_prompt, base_prompt=persona_base
    )

def make_question_id(level: int, question_nr: int) -> str:
    return f"{level}_{question_nr}"

def parse_question_id(question_id: str) -> tuple[int, int]:
    level, question_nr = question_id.split("_")
    return int(level), int(question_nr)

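# Example: make_question_id(1, 2) -> "1_2"; parse_question_id("1_2") -> (1, 2).
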
def _dispatch_nonempty(
    content: str, dispatch_event: Callable[[str, int], None], num: int
) -> None:
    if content != "":
        dispatch_event(content, num)

def dispatch_separated(
    tokens: Iterator[BaseMessage],
    dispatch_event: Callable[[str, int], None],
    sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
    num = 1
    streamed_tokens: list[BaseMessage_Content] = []
    for token in tokens:
        content = cast(str, token.content)
        if sep in content:
            sub_question_parts = content.split(sep)
            _dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
            num += 1
            # Note: if a single token contains multiple separators, everything
            # after the first separator is dispatched under one number.
            _dispatch_nonempty(
                "".join(sub_question_parts[1:]).strip(), dispatch_event, num
            )
        else:
            _dispatch_nonempty(content, dispatch_event, num)
        streamed_tokens.append(content)

    return streamed_tokens

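# Example (illustrative, with sep="|"): tokens "foo", "ba|r", "baz" dispatch
# ("foo", 1), ("ba", 1), ("r", 2), ("baz", 2) and return ["foo", "ba|r", "baz"].
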
def dispatch_main_answer_stop_info(level: int) -> None:
    stop_event = StreamStopInfo(
        stop_reason=StreamStopReason.FINISHED,
        stream_type="main_answer",
        level=level,
    )
    dispatch_custom_event("stream_finished", stop_event)

def get_today_prompt() -> str:
    return DATE_PROMPT.format(date=datetime.now().strftime("%A, %B %d, %Y"))

def retrieve_search_docs(
    search_tool: SearchTool, question: str
) -> list[InferenceSection]:
    retrieved_docs: list[InferenceSection] = []

    # new db session to avoid concurrency issues
    with get_session_context_manager() as db_session:
        for tool_response in search_tool.run(
            query=question,
            force_no_rerank=True,
            alternate_db_session=db_session,
        ):
            # get retrieved docs to send to the rest of the graph
            if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
                response = cast(SearchResponseSummary, tool_response.response)
                retrieved_docs = response.top_sections
                break

    return retrieved_docs

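# Note: only the SEARCH_RESPONSE_SUMMARY_ID response is consumed; iteration
# stops at the first summary and all other tool responses are ignored.
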
def get_answer_citation_ids(answer_str: str) -> list[int]:
    citation_ids = re.findall(r"\[\[D(\d+)\]\]", answer_str)
    return list({int(citation_id) - 1 for citation_id in citation_ids})

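# Example: get_answer_citation_ids("see [[D2]] and [[D4]]") yields the 0-based
# indices {1, 3}; ordering of the returned list is not guaranteed, since the
# values round-trip through a set.
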
def summarize_history(
    history: str, question: str, persona_specification: str, model: LLM
) -> str:
    history_context_prompt = HISTORY_CONTEXT_SUMMARY_PROMPT.format(
        persona_specification=persona_specification, question=question, history=history
    )

    history_response = model.invoke(history_context_prompt)

    if isinstance(history_response.content, str):
        history_context_response_str = history_response.content
    else:
        history_context_response_str = ""

    return history_context_response_str