import os import re from collections.abc import Callable from collections.abc import Iterator from collections.abc import Sequence from datetime import datetime from typing import Any from typing import cast from typing import Literal from typing import TypedDict from uuid import UUID from langchain_core.messages import BaseMessage from langchain_core.messages import HumanMessage from langgraph.types import StreamWriter from sqlalchemy.orm import Session from onyx.agents.agent_search.models import GraphConfig from onyx.agents.agent_search.models import GraphInputs from onyx.agents.agent_search.models import GraphPersistence from onyx.agents.agent_search.models import GraphSearchConfig from onyx.agents.agent_search.models import GraphTooling from onyx.agents.agent_search.shared_graph_utils.models import ( EntityRelationshipTermExtraction, ) from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions from onyx.chat.models import AnswerPacket from onyx.chat.models import AnswerStyleConfig from onyx.chat.models import CitationConfig from onyx.chat.models import DocumentPruningConfig from onyx.chat.models import PromptConfig from onyx.chat.models import SectionRelevancePiece from onyx.chat.models import StreamStopInfo from onyx.chat.models import StreamStopReason from onyx.chat.models import StreamType from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.configs.constants import DISPATCH_SEP_CHAR from onyx.configs.constants import FORMAT_DOCS_SEPARATOR from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDetails from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager from onyx.db.persona import get_persona_by_id from onyx.db.persona import Persona from onyx.llm.interfaces import LLM from onyx.prompts.agent_search import ( ASSISTANT_SYSTEM_PROMPT_DEFAULT, ) from onyx.prompts.agent_search import ( ASSISTANT_SYSTEM_PROMPT_PERSONA, ) from onyx.prompts.agent_search import ( HISTORY_CONTEXT_SUMMARY_PROMPT, ) from onyx.prompts.prompt_utils import handle_onyx_date_awareness from onyx.tools.force import ForceUseTool from onyx.tools.tool_constructor import SearchToolConfig from onyx.tools.tool_implementations.search.search_tool import ( SEARCH_RESPONSE_SUMMARY_ID, ) from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.utils import explicit_tool_calling_supported BaseMessage_Content = str | list[str | dict[str, Any]] # Post-processing def format_docs(docs: Sequence[InferenceSection]) -> str: formatted_doc_list = [] for doc_num, doc in enumerate(docs): title: str | None = doc.center_chunk.title metadata: dict[str, str | list[str]] | None = ( doc.center_chunk.metadata if doc.center_chunk.metadata else None ) doc_str = f"**Document: D{doc_num + 1}**" if title: doc_str += f"\nTitle: {title}" if metadata: metadata_str = "" for key, value in metadata.items(): if isinstance(value, str): metadata_str += f" - {key}: {value}" elif isinstance(value, list): metadata_str += f" - {key}: {', '.join(value)}" doc_str += f"\nMetadata: {metadata_str}" doc_str += f"\nContent:\n{doc.combined_content}" formatted_doc_list.append(doc_str) return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list) def format_entity_term_extraction( entity_term_extraction_dict: EntityRelationshipTermExtraction, ) -> str: entities = entity_term_extraction_dict.entities terms = entity_term_extraction_dict.terms relationships = entity_term_extraction_dict.relationships entity_strs = ["\nEntities:\n"] for entity in entities: entity_str = f"{entity.entity_name} ({entity.entity_type})" entity_strs.append(entity_str) entity_str = "\n - ".join(entity_strs) relationship_strs = ["\n\nRelationships:\n"] for relationship in relationships: relationship_name = relationship.relationship_name relationship_type = relationship.relationship_type relationship_entities = relationship.relationship_entities relationship_str = ( f"""{relationship_name} ({relationship_type}): {relationship_entities}""" ) relationship_strs.append(relationship_str) relationship_str = "\n - ".join(relationship_strs) term_strs = ["\n\nTerms:\n"] for term in terms: term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}" term_strs.append(term_str) term_str = "\n - ".join(term_strs) return "\n".join(entity_strs + relationship_strs + term_strs) def get_test_config( db_session: Session, primary_llm: LLM, fast_llm: LLM, search_request: SearchRequest, use_agentic_search: bool = True, ) -> GraphConfig: persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) document_pruning_config = DocumentPruningConfig( max_chunks=int( persona.num_chunks if persona.num_chunks is not None else MAX_CHUNKS_FED_TO_CHAT ), max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE, ) answer_style_config = AnswerStyleConfig( citation_config=CitationConfig( # The docs retrieved by this flow are already relevance-filtered all_docs_useful=True ), document_pruning_config=document_pruning_config, structured_response_format=None, ) search_tool_config = SearchToolConfig( answer_style_config=answer_style_config, document_pruning_config=document_pruning_config, retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True rerank_settings=None, # Can use this to change reranking model selected_sections=None, latest_query_files=None, bypass_acl=False, ) prompt_config = PromptConfig.from_model(persona.prompts[0]) search_tool = SearchTool( db_session=db_session, user=None, persona=persona, retrieval_options=search_tool_config.retrieval_options, prompt_config=prompt_config, llm=primary_llm, fast_llm=fast_llm, pruning_config=search_tool_config.document_pruning_config, answer_style_config=search_tool_config.answer_style_config, selected_sections=search_tool_config.selected_sections, chunks_above=search_tool_config.chunks_above, chunks_below=search_tool_config.chunks_below, full_doc=search_tool_config.full_doc, evaluation_type=( LLMEvaluationType.BASIC if persona.llm_relevance_filter else LLMEvaluationType.SKIP ), rerank_settings=search_tool_config.rerank_settings, bypass_acl=search_tool_config.bypass_acl, ) graph_inputs = GraphInputs( search_request=search_request, prompt_builder=AnswerPromptBuilder( user_message=HumanMessage(content=search_request.query), message_history=[], llm_config=primary_llm.config, raw_user_query=search_request.query, raw_user_uploaded_files=[], ), structured_response_format=answer_style_config.structured_response_format, ) using_tool_calling_llm = explicit_tool_calling_supported( primary_llm.config.model_provider, primary_llm.config.model_name ) graph_tooling = GraphTooling( primary_llm=primary_llm, fast_llm=fast_llm, search_tool=search_tool, tools=[search_tool], force_use_tool=ForceUseTool(force_use=False, tool_name=""), using_tool_calling_llm=using_tool_calling_llm, ) chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID") assert ( chat_session_id is not None ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests" graph_persistence = GraphPersistence( db_session=db_session, chat_session_id=UUID(chat_session_id), message_id=1, ) search_behavior_config = GraphSearchConfig( use_agentic_search=use_agentic_search, skip_gen_ai_answer_generation=False, allow_refinement=True, ) graph_config = GraphConfig( inputs=graph_inputs, tooling=graph_tooling, persistence=graph_persistence, behavior=search_behavior_config, ) return graph_config def get_persona_agent_prompt_expressions( persona: Persona | None, ) -> PersonaPromptExpressions: if persona is None or len(persona.prompts) == 0: # TODO base_prompt should be None, but no time to properly fix return PersonaPromptExpressions( contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_DEFAULT, base_prompt="" ) # Only a 1:1 mapping between personas and prompts currently prompt = persona.prompts[0] prompt_config = PromptConfig.from_model(prompt) datetime_aware_system_prompt = handle_onyx_date_awareness( prompt_str=prompt_config.system_prompt, prompt_config=prompt_config, add_additional_info_if_no_tag=prompt.datetime_aware, ) return PersonaPromptExpressions( contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_PERSONA.format( persona_prompt=datetime_aware_system_prompt ), base_prompt=datetime_aware_system_prompt, ) def make_question_id(level: int, question_num: int) -> str: return f"{level}_{question_num}" def parse_question_id(question_id: str) -> tuple[int, int]: level, question_num = question_id.split("_") return int(level), int(question_num) def _dispatch_nonempty( content: str, dispatch_event: Callable[[str, int], None], sep_num: int ) -> None: """ Dispatch a content string if it is not empty using the given callback. This function is used in the context of dispatching some arbitrary number of similar objects which are separated by a separator during the LLM stream. The callback expects a sep_num denoting which object is being dispatched; these numbers go from 1 to however many strings the LLM decides to stream. """ if content != "": dispatch_event(content, sep_num) def dispatch_separated( tokens: Iterator[BaseMessage], dispatch_event: Callable[[str, int], None], sep: str = DISPATCH_SEP_CHAR, ) -> list[BaseMessage_Content]: num = 1 streamed_tokens: list[BaseMessage_Content] = [] for token in tokens: content = cast(str, token.content) if sep in content: sub_question_parts = content.split(sep) _dispatch_nonempty(sub_question_parts[0], dispatch_event, num) num += 1 _dispatch_nonempty( "".join(sub_question_parts[1:]).strip(), dispatch_event, num ) else: _dispatch_nonempty(content, dispatch_event, num) streamed_tokens.append(content) return streamed_tokens def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None: stop_event = StreamStopInfo( stop_reason=StreamStopReason.FINISHED, stream_type=StreamType.MAIN_ANSWER, level=level, ) write_custom_event("stream_finished", stop_event, writer) def retrieve_search_docs( search_tool: SearchTool, question: str ) -> list[InferenceSection]: retrieved_docs: list[InferenceSection] = [] # new db session to avoid concurrency issues with get_session_context_manager() as db_session: for tool_response in search_tool.run( query=question, force_no_rerank=True, alternate_db_session=db_session, ): # get retrieved docs to send to the rest of the graph if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID: response = cast(SearchResponseSummary, tool_response.response) retrieved_docs = response.top_sections break return retrieved_docs def get_answer_citation_ids(answer_str: str) -> list[int]: """ Extract citation numbers of format [D] from the answer string. """ citation_ids = re.findall(r"\[D(\d+)\]", answer_str) return list(set([(int(id) - 1) for id in citation_ids])) def summarize_history( history: str, question: str, persona_specification: str | None, llm: LLM ) -> str: history_context_prompt = remove_document_citations( HISTORY_CONTEXT_SUMMARY_PROMPT.format( persona_specification=persona_specification, question=question, history=history, ) ) history_response = llm.invoke(history_context_prompt) assert isinstance(history_response.content, str) return history_response.content # taken from langchain_core.runnables.schema # we don't use the one from their library because # it includes ids they generate class CustomStreamEvent(TypedDict): # Overwrite the event field to be more specific. event: Literal["on_custom_event"] # type: ignore[misc] """The event type.""" name: str """User defined name for the event.""" data: Any """The data associated with the event. Free form and can be anything.""" def write_custom_event( name: str, event: AnswerPacket, stream_writer: StreamWriter ) -> None: stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event)) def relevance_from_docs( relevant_docs: list[InferenceSection], ) -> list[SectionRelevancePiece]: return [ SectionRelevancePiece( relevant=True, content=doc.center_chunk.content, document_id=doc.center_chunk.document_id, chunk_id=doc.center_chunk.chunk_id, ) for doc in relevant_docs ] def get_langgraph_node_log_string( graph_component: str, node_name: str, node_start_time: datetime, result: str | None = None, ) -> str: duration = datetime.now() - node_start_time results_str = "" if result is None else f" -- Result: {result}" return f"{node_start_time} -- {graph_component} - {node_name} -- Time taken: {duration}{results_str}" def remove_document_citations(text: str) -> str: """ Removes citation expressions of format '[[D1]]()' from text. The number after D can vary. Args: text: Input text containing citations Returns: Text with citations removed """ # Pattern explanation: # \[(?:D|Q)?\d+\] matches: # \[ - literal [ character # (?:D|Q)? - optional D or Q character # \d+ - one or more digits # \] - literal ] character return re.sub(r"\[(?:D|Q)?\d+\]", "", text)