2025-02-03 20:10:51 -08:00

434 lines
15 KiB
Python

import os
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import cast
from typing import Literal
from typing import TypedDict
from uuid import UUID
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DISPATCH_SEP_CHAR
from onyx.configs.constants import FORMAT_DOCS_SEPARATOR
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.prompts.agent_search import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
# Shape of a streamed message's content: either a plain string, or a list whose
# items are strings or string-keyed dicts.
# NOTE(review): presumably mirrors the type of langchain's `BaseMessage.content`
# -- confirm against the library.
BaseMessage_Content = str | list[str | dict[str, Any]]
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    """Render retrieved sections as a single text block.

    Each section becomes one entry that starts with ``**Document: D<n>**``
    (numbered from 1 in input order), followed by an optional ``Title:`` line,
    an optional ``Metadata:`` line and finally the section's combined content.
    Entries are joined with ``FORMAT_DOCS_SEPARATOR``.

    Args:
        docs: Sections to render, in the order they should be numbered.

    Returns:
        The joined text; an empty string when ``docs`` is empty.
    """

    def _render_metadata(metadata: dict[str, str | list[str]]) -> str:
        # Every supported value contributes " - key: value"; list values are
        # comma-joined. Values that are neither str nor list are skipped.
        pieces: list[str] = []
        for key, value in metadata.items():
            if isinstance(value, str):
                pieces.append(f" - {key}: {value}")
            elif isinstance(value, list):
                pieces.append(f" - {key}: {', '.join(value)}")
        return "".join(pieces)

    rendered: list[str] = []
    for position, section in enumerate(docs, start=1):
        chunk = section.center_chunk
        title: str | None = chunk.title
        # Empty metadata is treated the same as missing metadata.
        metadata: dict[str, str | list[str]] | None = chunk.metadata or None

        lines = [f"**Document: D{position}**"]
        if title:
            lines.append(f"Title: {title}")
        if metadata:
            lines.append(f"Metadata: {_render_metadata(metadata)}")
        lines.append(f"Content:\n{section.combined_content}")
        rendered.append("\n".join(lines))

    return FORMAT_DOCS_SEPARATOR.join(rendered)
def format_entity_term_extraction(
    entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
    """Render extracted entities, relationships and terms as plain text.

    The result has three sections, headed ``Entities:``, ``Relationships:``
    and ``Terms:`` in that order. Every header and every item is joined with a
    single newline:

    * entities are rendered as ``<name> (<type>)``
    * relationships as ``<name> (<type>): <entities>``
    * terms as ``<name> (<type>): similar to <comma-joined similar terms>``

    Args:
        entity_term_extraction_dict: Extraction result exposing ``entities``,
            ``relationships`` and ``terms``.

    Returns:
        The formatted multi-line string. The headers are always present, even
        when a section has no items.
    """
    entity_lines = ["\nEntities:\n"]
    entity_lines.extend(
        f"{entity.entity_name} ({entity.entity_type})"
        for entity in entity_term_extraction_dict.entities
    )

    relationship_lines = ["\n\nRelationships:\n"]
    relationship_lines.extend(
        f"{relationship.relationship_name} ({relationship.relationship_type}): "
        f"{relationship.relationship_entities}"
        for relationship in entity_term_extraction_dict.relationships
    )

    term_lines = ["\n\nTerms:\n"]
    term_lines.extend(
        f"{term.term_name} ({term.term_type}): "
        f"similar to {', '.join(term.term_similar_to)}"
        for term in entity_term_extraction_dict.terms
    )

    # NOTE(review): items are separated by plain newlines, without a " - "
    # bullet prefix. Bulleted per-section strings used to be computed here but
    # were never part of the return value, so they were dropped as dead code;
    # the returned text is unchanged.
    return "\n".join(entity_lines + relationship_lines + term_lines)
def get_test_config(
    db_session: Session,
    primary_llm: LLM,
    fast_llm: LLM,
    search_request: SearchRequest,
    use_agentic_search: bool = True,
) -> GraphConfig:
    """Assemble a ``GraphConfig`` around the default persona for backend tests.

    The persona identified by ``DEFAULT_PERSONA_ID`` is loaded and used to
    configure a ``SearchTool``; the search request's query becomes the user
    message of a fresh prompt builder with no history and no uploaded files.
    The chat session id is read from the ``ONYX_AS_CHAT_SESSION_ID``
    environment variable and the message id is fixed to ``1``.

    Args:
        db_session: Session used for the persona lookup, the search tool and
            the graph persistence settings.
        primary_llm: Main LLM; also decides whether explicit tool calling is
            supported.
        fast_llm: Secondary LLM handed to the search tool and graph tooling.
        search_request: Request whose ``query`` seeds the prompt builder.
        use_agentic_search: Forwarded to the graph's search behavior config.

    Returns:
        The fully populated ``GraphConfig``.

    Raises:
        AssertionError: If ``ONYX_AS_CHAT_SESSION_ID`` is not set.
    """
    persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)

    # Fall back to the global chunk limit when the persona does not set one.
    chunk_limit = persona.num_chunks
    if chunk_limit is None:
        chunk_limit = MAX_CHUNKS_FED_TO_CHAT
    pruning = DocumentPruningConfig(
        max_chunks=int(chunk_limit),
        max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
    )

    answer_style = AnswerStyleConfig(
        citation_config=CitationConfig(
            # The docs retrieved by this flow are already relevance-filtered
            all_docs_useful=True
        ),
        document_pruning_config=pruning,
        structured_response_format=None,
    )

    tool_config = SearchToolConfig(
        answer_style_config=answer_style,
        document_pruning_config=pruning,
        retrieval_options=RetrievalDetails(),  # may want to set dedupe_docs=True
        rerank_settings=None,  # Can use this to change reranking model
        selected_sections=None,
        latest_query_files=None,
        bypass_acl=False,
    )

    prompt_config = PromptConfig.from_model(persona.prompts[0])

    if persona.llm_relevance_filter:
        evaluation_type = LLMEvaluationType.BASIC
    else:
        evaluation_type = LLMEvaluationType.SKIP

    search_tool = SearchTool(
        db_session=db_session,
        user=None,
        persona=persona,
        retrieval_options=tool_config.retrieval_options,
        prompt_config=prompt_config,
        llm=primary_llm,
        fast_llm=fast_llm,
        pruning_config=tool_config.document_pruning_config,
        answer_style_config=tool_config.answer_style_config,
        selected_sections=tool_config.selected_sections,
        chunks_above=tool_config.chunks_above,
        chunks_below=tool_config.chunks_below,
        full_doc=tool_config.full_doc,
        evaluation_type=evaluation_type,
        rerank_settings=tool_config.rerank_settings,
        bypass_acl=tool_config.bypass_acl,
    )

    prompt_builder = AnswerPromptBuilder(
        user_message=HumanMessage(content=search_request.query),
        message_history=[],
        llm_config=primary_llm.config,
        raw_user_query=search_request.query,
        raw_user_uploaded_files=[],
    )
    inputs = GraphInputs(
        search_request=search_request,
        prompt_builder=prompt_builder,
        structured_response_format=answer_style.structured_response_format,
    )

    llm_config = primary_llm.config
    supports_tool_calling = explicit_tool_calling_supported(
        llm_config.model_provider, llm_config.model_name
    )
    tooling = GraphTooling(
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        tools=[search_tool],
        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
        using_tool_calling_llm=supports_tool_calling,
    )

    # The session id is checked only after the tooling has been built, so a
    # missing variable surfaces at this point and not earlier.
    session_id_text = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
    assert (
        session_id_text is not None
    ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
    persistence = GraphPersistence(
        db_session=db_session,
        chat_session_id=UUID(session_id_text),
        message_id=1,
    )

    return GraphConfig(
        inputs=inputs,
        tooling=tooling,
        persistence=persistence,
        behavior=GraphSearchConfig(
            use_agentic_search=use_agentic_search,
            skip_gen_ai_answer_generation=False,
            allow_refinement=True,
        ),
    )
def get_persona_agent_prompt_expressions(
    persona: Persona | None,
) -> PersonaPromptExpressions:
    """Build the system-prompt expressions for an optional persona.

    Without a persona, or when the persona has no prompts, the default
    assistant system prompt is used and the base prompt is an empty string.
    Otherwise the persona's first prompt is run through the date-awareness
    handling; the result is returned both as the base prompt and embedded in
    the persona system-prompt template.

    Args:
        persona: Persona to derive the prompts from, or ``None``.

    Returns:
        The contextualized and base prompt expressions.
    """
    available_prompts = persona.prompts if persona is not None else []
    if not available_prompts:
        # TODO base_prompt should be None, but no time to properly fix
        return PersonaPromptExpressions(
            contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_DEFAULT, base_prompt=""
        )

    # Only a 1:1 mapping between personas and prompts currently
    first_prompt = available_prompts[0]
    config = PromptConfig.from_model(first_prompt)
    system_prompt = handle_onyx_date_awareness(
        prompt_str=config.system_prompt,
        prompt_config=config,
        add_additional_info_if_no_tag=first_prompt.datetime_aware,
    )
    contextualized = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
        persona_prompt=system_prompt
    )
    return PersonaPromptExpressions(
        contextualized_prompt=contextualized,
        base_prompt=system_prompt,
    )
def make_question_id(level: int, question_num: int) -> str:
    """Build a question id of the form ``"<level>_<question_num>"``."""
    return "{}_{}".format(level, question_num)
def parse_question_id(question_id: str) -> tuple[int, int]:
    """Split a ``"<level>_<question_num>"`` id back into its two integers.

    Raises:
        ValueError: If the id does not consist of exactly two
            underscore-separated fields, or a field is not an integer.
    """
    # Unpacking enforces that there are exactly two fields.
    level_text, number_text = question_id.split("_")
    parsed = (int(level_text), int(number_text))
    return parsed
def _dispatch_nonempty(
content: str, dispatch_event: Callable[[str, int], None], sep_num: int
) -> None:
"""
Dispatch a content string if it is not empty using the given callback.
This function is used in the context of dispatching some arbitrary number
of similar objects which are separated by a separator during the LLM stream.
The callback expects a sep_num denoting which object is being dispatched; these
numbers go from 1 to however many strings the LLM decides to stream.
"""
if content != "":
dispatch_event(content, sep_num)
def dispatch_separated(
    tokens: Iterator[BaseMessage],
    dispatch_event: Callable[[str, int], None],
    sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
    """Stream token contents to a callback, numbering separator-delimited objects.

    Objects are numbered starting at 1. A token that does not contain ``sep``
    is dispatched under the current number. A token that does contain ``sep``
    is split: the text before the first separator is dispatched under the
    current number, the number is advanced by one, and all remaining pieces
    are concatenated (separators removed), stripped of surrounding whitespace
    and dispatched under the new number. The number advances once per such
    token, however many separators the token contains. Empty pieces are not
    dispatched.

    Args:
        tokens: Streamed messages; each ``content`` is treated as a string.
        dispatch_event: Callback receiving ``(text, object_number)``.
        sep: Separator marking the boundary between objects.

    Returns:
        The content of every token, in stream order, unmodified.
    """
    object_number = 1
    collected: list[BaseMessage_Content] = []
    for token in tokens:
        text = cast(str, token.content)
        if sep not in text:
            _dispatch_nonempty(text, dispatch_event, object_number)
        else:
            head, *tail = text.split(sep)
            _dispatch_nonempty(head, dispatch_event, object_number)
            object_number += 1
            remainder = "".join(tail).strip()
            _dispatch_nonempty(remainder, dispatch_event, object_number)
        collected.append(text)
    return collected
def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
    """Emit a ``stream_finished`` event saying the main answer stream finished.

    Args:
        level: Level recorded in the stop info.
        writer: Stream writer the custom event is sent through.
    """
    write_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type=StreamType.MAIN_ANSWER,
            level=level,
        ),
        writer,
    )
def retrieve_search_docs(
    search_tool: SearchTool, question: str
) -> list[InferenceSection]:
    """Run the search tool for ``question`` and return the top sections found.

    The tool is run with reranking disabled and with its own database session.
    Responses are consumed only until the search response summary appears;
    its ``top_sections`` are returned.

    Args:
        search_tool: Tool to run the query with.
        question: Query text.

    Returns:
        The summary's top sections, or an empty list when the tool never
        yields a search response summary.
    """
    # new db session to avoid concurrency issues
    with get_session_context_manager() as db_session:
        responses = search_tool.run(
            query=question,
            force_no_rerank=True,
            alternate_db_session=db_session,
        )
        for tool_response in responses:
            if tool_response.id != SEARCH_RESPONSE_SUMMARY_ID:
                continue
            # get retrieved docs to send to the rest of the graph
            summary = cast(SearchResponseSummary, tool_response.response)
            return summary.top_sections
    return []
def get_answer_citation_ids(answer_str: str) -> list[int]:
    """
    Extract citation numbers of format [D<number>] from the answer string.

    Citation numbers in the text are 1-based; the returned ids are shifted to
    be 0-based (``[D1]`` -> ``0``).

    Args:
        answer_str: Answer text that may contain ``[D<number>]`` citations.

    Returns:
        The distinct 0-based citation ids in ascending order; an empty list
        when the text has no such citation.
    """
    cited_numbers = re.findall(r"\[D(\d+)\]", answer_str)
    # Sorting the de-duplicated ids keeps the result deterministic; iterating
    # a bare set yields them in hash order instead.
    return sorted({int(number) - 1 for number in cited_numbers})
def summarize_history(
    history: str, question: str, persona_specification: str | None, llm: LLM
) -> str:
    """Ask the LLM to summarize the chat history in the context of a question.

    The history-summary prompt template is filled in, bracketed citation
    markers are removed from the resulting prompt, and the prompt is sent to
    ``llm``.

    Args:
        history: Conversation history text.
        question: Question the summary should be relevant to.
        persona_specification: Persona text inserted into the prompt.
            NOTE(review): ``None`` is interpolated as the text "None" by
            ``str.format`` -- confirm this is intended.
        llm: Model used to produce the summary.

    Returns:
        The content of the LLM response.

    Raises:
        AssertionError: If the response content is not a string.
    """
    filled_prompt = HISTORY_CONTEXT_SUMMARY_PROMPT.format(
        persona_specification=persona_specification,
        question=question,
        history=history,
    )
    prompt = remove_document_citations(filled_prompt)
    summary = llm.invoke(prompt).content
    assert isinstance(summary, str)
    return summary
# taken from langchain_core.runnables.schema
# we don't use the one from their library because
# it includes ids they generate
class CustomStreamEvent(TypedDict):
    """Dictionary shape of a custom event written to the stream.

    Keys:
        event: The event type; always the literal ``"on_custom_event"``.
        name: User defined name for the event.
        data: The data associated with the event. Free form and can be
            anything.
    """

    # Overwrite the event field to be more specific.
    event: Literal["on_custom_event"]  # type: ignore[misc]
    name: str
    data: Any
def write_custom_event(
    name: str, event: AnswerPacket, stream_writer: StreamWriter
) -> None:
    """Wrap ``event`` in a ``CustomStreamEvent`` and pass it to the stream writer.

    Args:
        name: Name stored on the custom event.
        event: Packet stored as the event's ``data``.
        stream_writer: Callable that receives the wrapped event.
    """
    wrapped = CustomStreamEvent(event="on_custom_event", name=name, data=event)
    stream_writer(wrapped)
def relevance_from_docs(
    relevant_docs: list[InferenceSection],
) -> list[SectionRelevancePiece]:
    """Create one relevance record, marked relevant, per given section.

    Each record carries the content, document id and chunk id of the
    section's center chunk.

    Args:
        relevant_docs: Sections to mark as relevant.

    Returns:
        The relevance records, in the same order as ``relevant_docs``.
    """
    pieces: list[SectionRelevancePiece] = []
    for section in relevant_docs:
        chunk = section.center_chunk
        pieces.append(
            SectionRelevancePiece(
                relevant=True,
                content=chunk.content,
                document_id=chunk.document_id,
                chunk_id=chunk.chunk_id,
            )
        )
    return pieces
def get_langgraph_node_log_string(
graph_component: str,
node_name: str,
node_start_time: datetime,
result: str | None = None,
) -> str:
duration = datetime.now() - node_start_time
results_str = "" if result is None else f" -- Result: {result}"
return f"{node_start_time} -- {graph_component} - {node_name} -- Time taken: {duration}{results_str}"
def remove_document_citations(text: str) -> str:
    """
    Removes bracketed citation markers such as '[D1]', '[Q2]' or '[3]' from text.

    The letter is optional and the number can vary. Only the single-bracket
    marker itself is deleted; any surrounding characters are kept, so
    '[[D1]]()' becomes '[]()'.

    Args:
        text: Input text containing citations

    Returns:
        Text with citations removed
    """
    # Pattern pieces:
    #   \[         literal [ character
    #   (?:D|Q)?   optional D or Q character
    #   \d+        one or more digits
    #   \]         literal ] character
    citation_marker = re.compile(r"\[(?:D|Q)?\d+\]")
    return citation_marker.sub("", text)