2025-02-03 20:10:50 -08:00

371 lines
13 KiB
Python

import ast
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from uuid import UUID
from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DISPATCH_SEP_CHAR
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
# Alias for a message content payload: either a plain string, or a list whose
# items are strings or string-keyed dicts. Used as the element type of the
# list returned by dispatch_separated.
BaseMessage_Content = str | list[str | dict[str, Any]]
def normalize_whitespace(text: str) -> str:
    """Collapse every run of whitespace in ``text`` into a single space.

    Leading and trailing whitespace is stripped first, so the result never
    starts or ends with a space.

    Args:
        text: The string to normalize.

    Returns:
        ``text`` with surrounding whitespace removed and each internal run of
        whitespace (spaces, tabs, newlines, ...) replaced by one space.
    """
    # `re` is imported at module level; a function-scope import is redundant.
    return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    """Render ``docs`` as numbered document blocks.

    Each section becomes a ``Document D<n>:`` header line followed by its
    ``combined_content``, with ``n`` counted from 1. Blocks are separated by
    one blank line; an empty sequence yields an empty string.
    """
    numbered_blocks = (
        f"Document D{position}:\n{section.combined_content}"
        for position, section in enumerate(docs, start=1)
    )
    return "\n\n".join(numbered_blocks)
def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
    """Render ``docs`` as unnumbered excerpts.

    Each section's ``combined_content`` is prefixed with a newline and an
    ellipsis and followed by a newline; the resulting pieces are joined with
    a blank line. An empty sequence yields an empty string.
    """
    excerpts = [f"\n...{section.combined_content}\n" for section in docs]
    return "\n\n".join(excerpts)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
    """Extract and parse a list literal embedded in free-form text.

    Cleaning steps, in order:
      1. Drop everything that precedes the first ``[``.
      2. Remove Markdown ``json`` code-fence markers.
      3. Replace backslash-n escape sequences and real newlines with spaces,
         then collapse all whitespace runs to single spaces.

    The cleaned text is parsed as JSON first; if that fails it is parsed as a
    Python literal (which tolerates e.g. single-quoted strings).

    Args:
        json_string: Raw text that is expected to contain a list.

    Returns:
        The parsed value. NOTE: the result is returned as parsed; it is not
        checked to actually be a list of dicts.

    Raises:
        ValueError: If the cleaned text is neither valid JSON nor a valid
            Python literal.
    """
    # Remove any prefixes/labels before the actual JSON content
    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)

    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Try parsing with json.loads first, fall back to ast.literal_eval
    try:
        return json.loads(cleaned_string)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(cleaned_string)
        # literal_eval raises TypeError for inputs such as an unhashable dict
        # key ("{[1]: 2}"); fold it into the same ValueError as other failures.
        except (ValueError, SyntaxError, TypeError) as e:
            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    """Strip Markdown fencing and newlines from ``json_string`` and parse it.

    Markdown ``json`` code-fence markers are removed, backslash-n escape
    sequences and real newlines become spaces, and whitespace runs are
    collapsed to single spaces (this also affects whitespace inside string
    values). The result is handed to ``json.loads``.

    Raises:
        json.JSONDecodeError: If the cleaned text is not valid JSON.
    """
    text = re.sub(r"```json\n|\n```", "", json_string)

    # Flatten both the escaped and the literal newline forms.
    for newline_form in ("\\n", "\n"):
        text = text.replace(newline_form, " ")

    words = text.split()
    return json.loads(" ".join(words))
def format_entity_term_extraction(
    entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
    """Render extracted entities, relationships and terms as plain text.

    The output has three sections, headed ``Entities:``, ``Relationships:``
    and ``Terms:``. Within each section there is one line per item:

      * entity:        ``<name> (<type>)``
      * relationship:  ``<name> (<type>): <entities>``
      * term:          ``<name> (<type>): similar to <comma-separated terms>``

    All header and item strings are joined with a single newline; the headers
    themselves carry extra newlines that provide the spacing between sections.

    Args:
        entity_term_extraction_dict: Object exposing ``entities``,
            ``relationships`` and ``terms`` collections.

    Returns:
        The formatted multi-section string.
    """
    # NOTE(review): per-section "\n - " bullet joins used to be computed here
    # but were never part of the returned value. They are removed as dead code
    # and the returned text is unchanged.
    entity_lines = [
        f"{entity.entity_name} ({entity.entity_type})"
        for entity in entity_term_extraction_dict.entities
    ]
    relationship_lines = [
        f"{rel.relationship_name} ({rel.relationship_type}): {rel.relationship_entities}"
        for rel in entity_term_extraction_dict.relationships
    ]
    term_lines = [
        f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
        for term in entity_term_extraction_dict.terms
    ]

    return "\n".join(
        [
            "\nEntities:\n",
            *entity_lines,
            "\n\nRelationships:\n",
            *relationship_lines,
            "\n\nTerms:\n",
            *term_lines,
        ]
    )
def _format_time_delta(time: timedelta) -> str:
    """Format ``time`` as ``<seconds>.<microseconds>``.

    The whole-second count is zero-padded to at least three digits and the
    microsecond part to six digits, e.g. ``005.001234``.

    Args:
        time: The elapsed duration to format.

    Returns:
        The formatted duration. Durations of a day or more are reported as
        their full second count, and negative durations are prefixed with
        ``-``.
    """
    # `timedelta.seconds` only holds the sub-day component (0..86399) and is
    # normalised oddly for negative values (e.g. -0.5s has seconds == 86399),
    # so work from the absolute value and fold the days back in.
    sign = "-" if time < timedelta(0) else ""
    magnitude = abs(time)
    whole_seconds = magnitude.days * 86400 + magnitude.seconds
    return f"{sign}{whole_seconds:03d}.{magnitude.microseconds:06d}"
def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    """Prefix ``message`` with elapsed-time stamps for logging.

    The result has the form ``<graph elapsed> (<node elapsed> s): <message>``.
    Both elapsed values are measured against a single ``datetime.now()``
    reading and rendered by ``_format_time_delta``.

    Args:
        message: The text to log.
        node_start_time: When the current node started.
        graph_start_time: When the whole graph started; if ``None`` the graph
            elapsed field is rendered as ``N/A``.

    Returns:
        The formatted log line.
    """
    now = datetime.now()

    graph_elapsed = (
        "N/A"
        if graph_start_time is None
        else _format_time_delta(now - graph_start_time)
    )
    node_elapsed = _format_time_delta(now - node_start_time)

    return f"{graph_elapsed} ({node_elapsed} s): {message}"
def get_test_config(
    db_session: Session,
    primary_llm: LLM,
    fast_llm: LLM,
    search_request: SearchRequest,
    use_agentic_search: bool = True,
) -> tuple[AgentSearchConfig, SearchTool]:
    """Build an ``AgentSearchConfig`` and its ``SearchTool`` for the default persona.

    The persona is loaded via ``DEFAULT_PERSONA_ID``; its first prompt supplies
    the prompt config. The returned config uses a hard-coded chat session id
    and ``message_id=1`` (judging by the function name, this is meant for
    test/dev runs -- confirm before using elsewhere).

    Args:
        db_session: Session used to load the persona and passed on to the
            search tool and the config.
        primary_llm: Main LLM; also provides ``llm_config`` for the prompt builder.
        fast_llm: Secondary LLM passed to the search tool and the config.
        search_request: The request whose ``query`` seeds the prompt builder.
        use_agentic_search: Forwarded unchanged to ``AgentSearchConfig``.

    Returns:
        ``(config, search_tool)``; the same ``search_tool`` instance is also
        referenced from ``config``.
    """
    persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)

    # Prefer the persona's own chunk budget, else the global chat default.
    if persona.num_chunks is not None:
        chunk_budget = int(persona.num_chunks)
    else:
        chunk_budget = int(MAX_CHUNKS_FED_TO_CHAT)
    pruning_config = DocumentPruningConfig(
        max_chunks=chunk_budget,
        max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
    )

    # The docs retrieved by this flow are already relevance-filtered
    citation_config = CitationConfig(all_docs_useful=True)
    style_config = AnswerStyleConfig(
        citation_config=citation_config,
        document_pruning_config=pruning_config,
        structured_response_format=None,
    )

    tool_settings = SearchToolConfig(
        answer_style_config=style_config,
        document_pruning_config=pruning_config,
        retrieval_options=RetrievalDetails(),  # may want to set dedupe_docs=True
        rerank_settings=None,  # Can use this to change reranking model
        selected_sections=None,
        latest_query_files=None,
        bypass_acl=False,
    )

    prompt_config = PromptConfig.from_model(persona.prompts[0])

    if persona.llm_relevance_filter:
        evaluation_type = LLMEvaluationType.BASIC
    else:
        evaluation_type = LLMEvaluationType.SKIP

    search_tool = SearchTool(
        db_session=db_session,
        user=None,
        persona=persona,
        retrieval_options=tool_settings.retrieval_options,
        prompt_config=prompt_config,
        llm=primary_llm,
        fast_llm=fast_llm,
        pruning_config=tool_settings.document_pruning_config,
        answer_style_config=tool_settings.answer_style_config,
        selected_sections=tool_settings.selected_sections,
        chunks_above=tool_settings.chunks_above,
        chunks_below=tool_settings.chunks_below,
        full_doc=tool_settings.full_doc,
        evaluation_type=evaluation_type,
        rerank_settings=tool_settings.rerank_settings,
        bypass_acl=tool_settings.bypass_acl,
    )

    force_use_tool = ForceUseTool(force_use=False, tool_name="")
    prompt_builder = AnswerPromptBuilder(
        user_message=HumanMessage(content=search_request.query),
        message_history=[],
        llm_config=primary_llm.config,
        raw_user_query=search_request.query,
        raw_user_uploaded_files=[],
    )

    # Hard-coded chat session id (annotated "Joachim" in the original).
    # Other ids that were kept as commented-out alternatives:
    #   123e4567-e89b-12d3-a456-426614174000
    #   d1acd613-2692-4bc3-9d65-c6d3da62e58e (annotated "Evan")
    session_id = UUID("edda10d5-6cef-45d8-acfb-39317552a1f4")

    config = AgentSearchConfig(
        search_request=search_request,
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        force_use_tool=force_use_tool,
        prompt_builder=prompt_builder,
        chat_session_id=session_id,
        message_id=1,
        use_agentic_persistence=True,
        db_session=db_session,
        tools=[search_tool],
        use_agentic_search=use_agentic_search,
    )

    return config, search_tool
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
    """Derive the system-prompt expressions for ``persona``.

    With a persona, the base prompt is the newline-joined ``system_prompt`` of
    each of its prompts, and the contextualized prompt is
    ``ASSISTANT_SYSTEM_PROMPT_PERSONA`` formatted with that base prompt.
    Without a persona, the base prompt is empty and the contextualized prompt
    is ``ASSISTANT_SYSTEM_PROMPT_DEFAULT``.
    """
    if persona is not None:
        base_prompt = "\n".join(prompt.system_prompt for prompt in persona.prompts)
        contextualized_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
            persona_prompt=base_prompt
        )
    else:
        base_prompt = ""
        contextualized_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT

    return PersonaExpressions(
        contextualized_prompt=contextualized_prompt,
        base_prompt=base_prompt,
    )
def make_question_id(level: int, question_nr: int) -> str:
    """Build a question id of the form ``<level>_<question_nr>``.

    ``parse_question_id`` performs the inverse operation.
    """
    question_id = f"{level}_{question_nr}"
    return question_id
def parse_question_id(question_id: str) -> tuple[int, int]:
    """Split a ``<level>_<question_nr>`` id back into its two integers.

    Raises:
        ValueError: If ``question_id`` does not consist of exactly two
            ``_``-separated parts, or a part is not an integer.
    """
    pieces = question_id.split("_")
    level_text, number_text = pieces
    return (int(level_text), int(number_text))
def _dispatch_nonempty(
    content: str, dispatch_event: Callable[[str, int], None], num: int
) -> None:
    """Call ``dispatch_event(content, num)`` unless ``content`` is the empty string."""
    if content == "":
        # Nothing to emit for an empty chunk.
        return
    dispatch_event(content, num)
def dispatch_separated(
    tokens: Iterator[BaseMessage],
    dispatch_event: Callable[[str, int], None],
    sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
    """Stream token contents to ``dispatch_event``, numbering segments split by ``sep``.

    A running counter starts at 1. For a token that does not contain ``sep``,
    its content is dispatched under the current counter. For a token that does
    contain ``sep``, the text before the first ``sep`` is dispatched under the
    current counter, the counter is incremented once, and the rest of the
    token (all further ``sep`` occurrences removed, then stripped) is
    dispatched under the new counter. Empty strings are never dispatched.

    Args:
        tokens: Messages whose ``content`` is treated as a string.
        dispatch_event: Callback receiving ``(content, counter)``.
        sep: Separator that marks the start of the next numbered segment.

    Returns:
        The content of every token, in arrival order.
    """
    segment_num = 1
    collected: list[BaseMessage_Content] = []

    for message in tokens:
        text = cast(str, message.content)

        if sep not in text:
            _dispatch_nonempty(text, dispatch_event, segment_num)
        else:
            head, *tail = text.split(sep)
            _dispatch_nonempty(head, dispatch_event, segment_num)
            # The counter advances once per token, however many separators
            # the token contains.
            segment_num += 1
            _dispatch_nonempty("".join(tail).strip(), dispatch_event, segment_num)

        collected.append(text)

    return collected
def dispatch_main_answer_stop_info(level: int) -> None:
    """Dispatch a ``stream_finished`` custom event for the main answer stream.

    The event payload is a ``StreamStopInfo`` with stop reason ``FINISHED``,
    stream type ``main_answer`` and the given ``level``.
    """
    dispatch_custom_event(
        "stream_finished",
        StreamStopInfo(
            stop_reason=StreamStopReason.FINISHED,
            stream_type="main_answer",
            level=level,
        ),
    )
def get_today_prompt() -> str:
    """Return ``DATE_PROMPT`` filled in with today's date.

    The date is taken from ``datetime.now()`` and rendered as
    ``<weekday>, <month name> <zero-padded day>, <year>``.
    """
    today = datetime.now().strftime("%A, %B %d, %Y")
    return DATE_PROMPT.format(date=today)
def retrieve_search_docs(
    search_tool: SearchTool, question: str
) -> list[InferenceSection]:
    """Run ``search_tool`` for ``question`` and return the top retrieved sections.

    The tool is run with ``force_no_rerank=True`` and a dedicated DB session.
    Tool responses are consumed until the first one whose id is
    ``SEARCH_RESPONSE_SUMMARY_ID``; its ``top_sections`` are returned. If no
    such response is produced, an empty list is returned.
    """
    sections: list[InferenceSection] = []

    # new db session to avoid concurrency issues
    with get_session_context_manager() as fresh_session:
        for item in search_tool.run(
            query=question,
            force_no_rerank=True,
            alternate_db_session=fresh_session,
        ):
            if item.id != SEARCH_RESPONSE_SUMMARY_ID:
                continue
            # get retrieved docs to send to the rest of the graph
            summary = cast(SearchResponseSummary, item.response)
            sections = summary.top_sections
            break

    return sections
def get_answer_citation_ids(answer_str: str) -> list[int]:
    """Return the zero-based indices of the documents cited in ``answer_str``.

    Citations are markers of the form ``[[D<n>]]``. Document numbers are
    1-based (``format_docs`` labels documents starting at ``D1``), so each
    ``n`` maps to index ``n - 1``.

    Args:
        answer_str: Answer text that may contain citation markers.

    Returns:
        The distinct cited indices in ascending order; an empty list if there
        are no citations.
    """
    doc_numbers = {int(raw) for raw in re.findall(r"\[\[D(\d+)\]\]", answer_str)}
    # Skip D0: it is not a valid document number, and its index (-1) would
    # silently address the last document when used for list indexing.
    return sorted(number - 1 for number in doc_numbers if number >= 1)
def summarize_history(
    history: str, question: str, persona_specification: str, model: LLM
) -> str:
    """Ask ``model`` to summarize ``history`` in the context of ``question``.

    The prompt is ``HISTORY_CONTEXT_SUMMARY_PROMPT`` formatted with the
    persona specification, the question and the history.

    Returns:
        The response content when it is a string; otherwise an empty string.
    """
    prompt = HISTORY_CONTEXT_SUMMARY_PROMPT.format(
        persona_specification=persona_specification,
        question=question,
        history=history,
    )
    response = model.invoke(prompt)

    # Non-string content (e.g. a list of parts) is not summarized further.
    if not isinstance(response.content, str):
        return ""
    return response.content