From bd3b1943c4b358186e33b32a2a171d46ff806d37 Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Thu, 30 Jan 2025 09:46:53 -0800 Subject: [PATCH] WIP PR comments --- .../nodes/answer_check.py | 8 +-- .../nodes/format_answer.py | 4 +- .../states.py | 2 +- .../nodes/generate_initial_answer.py | 13 ++-- .../nodes/generate_refined_answer.py | 2 +- .../nodes/refined_sub_question_creation.py | 4 +- .../edges.py | 5 +- .../nodes/expand_queries.py | 7 +- .../nodes/verification_kickoff.py | 10 +-- .../operations.py | 24 +++---- .../states.py | 2 +- backend/onyx/agents/agent_search/models.py | 70 ++++++++++++++++++- .../agent_search/shared_graph_utils/models.py | 2 +- .../agent_search/shared_graph_utils/utils.py | 15 ++-- backend/onyx/configs/constants.py | 1 + .../regression/answer_quality/agent_test.py | 6 +- 16 files changed, 111 insertions(+), 64 deletions(-) diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/answer_check.py b/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/answer_check.py index 39b053c67..a58debba2 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/answer_check.py +++ b/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/answer_check.py @@ -12,7 +12,6 @@ from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subg QACheckUpdate, ) from onyx.agents.agent_search.models import AgentSearchConfig -from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id @@ -25,7 +24,7 @@ def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckU if state.answer == UNKNOWN_ANSWER: now_end = datetime.now() return QACheckUpdate( - answer_quality=SUB_CHECK_NO, + answer_quality=False, log_messages=[ f"{now_start} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}" ], @@ -47,11 +46,12 @@ def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckU ) ) - quality_str = merge_message_runs(response, chunk_separator="")[0].content + quality_str: str = merge_message_runs(response, chunk_separator="")[0].content + answer_quality = "yes" in quality_str.lower() now_end = datetime.now() return QACheckUpdate( - answer_quality=quality_str, + answer_quality=answer_quality, log_messages=[ f"""{now_start} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str}, Time taken: {now_end - now_start}""" diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/format_answer.py b/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/format_answer.py index 92f1d5b8c..469e67f52 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/format_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/nodes/format_answer.py @@ -15,9 +15,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput: QuestionAnswerResults( question=state.question, question_id=state.question_id, - quality=state.answer_quality - if hasattr(state, "answer_quality") - else "No", + verified_high_quality=state.answer_quality, answer=state.answer, expanded_retrieval_results=state.expanded_retrieval_results, documents=state.documents, diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/states.py b/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/states.py index 4c03aa43f..9a814c9cb 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/initial__individual_sub_answer__subgraph/states.py @@ -17,7 +17,7 @@ from onyx.context.search.models import InferenceSection ## Update States class QACheckUpdate(BaseModel): - answer_quality: str = "" + answer_quality: bool = False log_messages: list[str] = [] diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial__retrieval_sub_answers__subgraph/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/initial__retrieval_sub_answers__subgraph/nodes/generate_initial_answer.py index eeb17e803..e427b208f 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/initial__retrieval_sub_answers__subgraph/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/initial__retrieval_sub_answers__subgraph/nodes/generate_initial_answer.py @@ -45,7 +45,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import ( dispatch_main_answer_stop_info, ) from onyx.agents.agent_search.shared_graph_utils.utils import format_docs -from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import ExtendedToolResponse from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS @@ -136,9 +135,8 @@ def generate_initial_answer( for decomp_answer_result in decomp_answer_results: decomp_questions.append(decomp_answer_result.question) - _, question_nr = parse_question_id(decomp_answer_result.question_id) if ( - decomp_answer_result.quality.lower().startswith("yes") + decomp_answer_result.verified_high_quality and len(decomp_answer_result.answer) > 0 and decomp_answer_result.answer != UNKNOWN_ANSWER ): @@ -151,15 +149,12 @@ def generate_initial_answer( ) sub_question_nr += 1 - if len(good_qa_list) > 0: - sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) - else: - sub_question_answer_str = "" - # Determine which base prompt to use given the sub-question information if len(good_qa_list) > 0: + sub_question_answer_str = "\n\n------\n\n".join(good_qa_list) base_prompt = INITIAL_RAG_PROMPT else: + sub_question_answer_str = "" base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS model = agent_a_config.fast_llm @@ -182,7 +177,7 @@ def generate_initial_answer( answered_sub_questions=remove_document_citations( sub_question_answer_str ), - relevant_docs=format_docs(relevant_docs), + relevant_docs=doc_context, persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt, history=prompt_enrichment_components.history, date_prompt=prompt_enrichment_components.date_str, diff --git a/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/generate_refined_answer.py index 0c0cfd865..114b904df 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/generate_refined_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/generate_refined_answer.py @@ -137,7 +137,7 @@ def generate_refined_answer( decomp_questions.append(decomp_answer_result.question) if ( - decomp_answer_result.quality.lower().startswith("yes") + decomp_answer_result.verified_high_quality and len(decomp_answer_result.answer) > 0 and decomp_answer_result.answer != UNKNOWN_ANSWER ): diff --git a/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/refined_sub_question_creation.py b/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/refined_sub_question_creation.py index c4e0c6b43..b0a82a772 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/refined_sub_question_creation.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main__graph/nodes/refined_sub_question_creation.py @@ -67,11 +67,11 @@ def refined_sub_question_creation( initial_question_answers = state.decomp_answer_results addressed_question_list = [ - x.question for x in initial_question_answers if "yes" in x.quality.lower() + x.question for x in initial_question_answers if x.verified_high_quality ] failed_question_list = [ - x.question for x in initial_question_answers if "no" in x.quality.lower() + x.question for x in initial_question_answers if not x.verified_high_quality ] msg = [ diff --git a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/edges.py b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/edges.py index dc9771e09..a93b18fb9 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/edges.py +++ b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/edges.py @@ -19,9 +19,8 @@ def parallel_retrieval_edge( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) question = state.question if state.question else agent_a_config.search_request.query - query_expansions = ( - state.expanded_queries if state.expanded_queries else [] + [question] - ) + query_expansions = state.expanded_queries + [question] + return [ Send( "doc_retrieval", diff --git a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/expand_queries.py b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/expand_queries.py index 3e9dc0d2e..c34eeee7e 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/expand_queries.py +++ b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/expand_queries.py @@ -33,11 +33,8 @@ def expand_queries( # Instead, we use the original question from the search request. agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) now_start = datetime.now() - question = ( - state.question - if hasattr(state, "question") - else agent_a_config.search_request.query - ) + question = state.question + llm = agent_a_config.fast_llm chat_session_id = agent_a_config.chat_session_id sub_question_id = state.sub_question_id diff --git a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/verification_kickoff.py b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/verification_kickoff.py index 10a3cb333..261a596a5 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/verification_kickoff.py +++ b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/nodes/verification_kickoff.py @@ -1,4 +1,3 @@ -from typing import cast from typing import Literal from langchain_core.runnables.config import RunnableConfig @@ -11,7 +10,6 @@ from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.s from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import ( ExpandedRetrievalState, ) -from onyx.agents.agent_search.models import AgentSearchConfig def verification_kickoff( @@ -19,12 +17,8 @@ def verification_kickoff( config: RunnableConfig, ) -> Command[Literal["doc_verification"]]: documents = state.retrieved_documents - agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) - verification_question = ( - state.question - if hasattr(state, "question") - else agent_a_config.search_request.query - ) + verification_question = state.question + sub_question_id = state.sub_question_id return Command( update={}, diff --git a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/operations.py b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/operations.py index 86434a669..03cddd05a 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/operations.py +++ b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/operations.py @@ -51,23 +51,15 @@ def calculate_sub_question_retrieval_stats( raw_chunk_stats_counts: dict[str, int] = defaultdict(int) raw_chunk_stats_scores: dict[str, float] = defaultdict(float) for doc_chunk_id, chunk_data in chunk_scores.items(): - if doc_chunk_id in verified_doc_chunk_ids: - raw_chunk_stats_counts["verified_count"] += 1 + valid_chunk_scores = [ + score for score in chunk_data["score"] if score is not None + ] + key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected" + raw_chunk_stats_counts[f"{key}_count"] += 1 - valid_chunk_scores = [ - score for score in chunk_data["score"] if score is not None - ] - raw_chunk_stats_scores["verified_scores"] += float( - np.mean(valid_chunk_scores) - ) - else: - raw_chunk_stats_counts["rejected_count"] += 1 - valid_chunk_scores = [ - score for score in chunk_data["score"] if score is not None - ] - raw_chunk_stats_scores["rejected_scores"] += float( - np.mean(valid_chunk_scores) - ) + raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores)) + + if key == "rejected": dismissed_doc_chunk_ids.append(doc_chunk_id) if raw_chunk_stats_counts["verified_count"] == 0: diff --git a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/states.py b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/states.py index 2b90618b4..c2704bdec 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/util__expanded_retrieval__subgraph/states.py @@ -30,7 +30,7 @@ class ExpandedRetrievalInput(SubgraphCoreState): class QueryExpansionUpdate(BaseModel): - expanded_queries: list[str] = ["aaa", "bbb"] + expanded_queries: list[str] = [] log_messages: list[str] = [] diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index 517d66d74..e06211cd1 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -19,7 +19,7 @@ class AgentSearchConfig(BaseModel): Configuration for the Agent Search feature. """ - # The search request that was used to generate the Pro Search + # The search request that was used to generate the Agent Search search_request: SearchRequest primary_llm: LLM @@ -45,7 +45,7 @@ class AgentSearchConfig(BaseModel): # The message ID of the user message that triggered the Pro Search message_id: int | None = None - # Whether to persistence data for Agentic Search (turned off for testing) + # Whether to persist data for Agentic Search use_agentic_persistence: bool = True # The database session for Agentic Search @@ -89,6 +89,72 @@ class AgentSearchConfig(BaseModel): arbitrary_types_allowed = True +class GraphInputs(BaseModel): + """Input data required for the graph execution""" + + search_request: SearchRequest + prompt_builder: AnswerPromptBuilder + files: list[InMemoryChatFile] | None = None + structured_response_format: dict | None = None + + +class GraphTooling(BaseModel): + """Tools and LLMs available to the graph""" + + primary_llm: LLM + fast_llm: LLM + search_tool: SearchTool | None = None + tools: list[Tool] | None = None + force_use_tool: ForceUseTool + using_tool_calling_llm: bool = False + + +class GraphPersistence(BaseModel): + """Configuration for data persistence""" + + chat_session_id: UUID | None = None + message_id: int | None = None + use_agentic_persistence: bool = True + db_session: Session | None = None + + @model_validator(mode="after") + def validate_db_session(self) -> "GraphPersistence": + if self.use_agentic_persistence and self.db_session is None: + raise ValueError( + "db_session must be provided for pro search when using persistence" + ) + return self + + +class SearchBehaviorConfig(BaseModel): + """Configuration controlling search behavior""" + + use_agentic_search: bool = False + perform_initial_search_decomposition: bool = True + allow_refinement: bool = True + skip_gen_ai_answer_generation: bool = False + + +class GraphConfig(BaseModel): + """ + Main configuration class that combines all config components for Langgraph execution + """ + + inputs: GraphInputs + tooling: GraphTooling + persistence: GraphPersistence + behavior: SearchBehaviorConfig + + @model_validator(mode="after") + def validate_search_tool(self) -> "GraphConfig": + if self.behavior.use_agentic_search and self.tooling.search_tool is None: + raise ValueError("search_tool must be provided for agentic search") + return self + + class Config: + arbitrary_types_allowed = True + + class AgentDocumentCitations(BaseModel): document_id: str document_title: str diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/models.py b/backend/onyx/agents/agent_search/shared_graph_utils/models.py index c79577cbf..a399fa1df 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/models.py @@ -103,7 +103,7 @@ class QuestionAnswerResults(BaseModel): question: str question_id: str answer: str - quality: str + verified_high_quality: bool expanded_retrieval_results: list[QueryResult] documents: list[InferenceSection] context_documents: list[InferenceSection] diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index 49200279c..b14405c87 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -40,6 +40,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from onyx.configs.constants import DEFAULT_PERSONA_ID +from onyx.configs.constants import DISPATCH_SEP_CHAR from onyx.context.search.enums import LLMEvaluationType from onyx.context.search.models import InferenceSection from onyx.context.search.models import RetrievalDetails @@ -56,6 +57,8 @@ from onyx.tools.tool_implementations.search.search_tool import ( from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary from onyx.tools.tool_implementations.search.search_tool import SearchTool +BaseMessage_Content = str | list[str | dict[str, Any]] + def normalize_whitespace(text: str) -> str: """Normalize whitespace in text to single spaces and strip leading/trailing whitespace.""" @@ -289,14 +292,14 @@ def _dispatch_nonempty( def dispatch_separated( - token_itr: Iterator[BaseMessage], + tokens: Iterator[BaseMessage], dispatch_event: Callable[[str, int], None], - sep: str = "\n", -) -> list[str | list[str | dict[str, Any]]]: + sep: str = DISPATCH_SEP_CHAR, +) -> list[BaseMessage_Content]: num = 1 - streamed_tokens: list[str | list[str | dict[str, Any]]] = [""] - for message in token_itr: - content = cast(str, message.content) + streamed_tokens: list[BaseMessage_Content] = [] + for token in tokens: + content = cast(str, token.content) if sep in content: sub_question_parts = content.split(sep) _dispatch_nonempty(sub_question_parts[0], dispatch_event, num) diff --git a/backend/onyx/configs/constants.py b/backend/onyx/configs/constants.py index 605059081..cfaf0f81d 100644 --- a/backend/onyx/configs/constants.py +++ b/backend/onyx/configs/constants.py @@ -42,6 +42,7 @@ DEFAULT_CC_PAIR_ID = 1 BASIC_KEY = (-1, -1) AGENT_SEARCH_INITIAL_KEY = (0, 0) CANCEL_CHECK_INTERVAL = 20 +DISPATCH_SEP_CHAR = "\n" # Postgres connection constants for application_name POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test.py index 25b19964e..f892d37f0 100644 --- a/backend/tests/regression/answer_quality/agent_test.py +++ b/backend/tests/regression/answer_quality/agent_test.py @@ -5,8 +5,10 @@ import os import yaml -from onyx.agents.agent_search.deep_search_a.main.graph_builder import main_graph_builder -from onyx.agents.agent_search.deep_search_a.main.states import MainInput +from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import ( + main_graph_builder, +) +from onyx.agents.agent_search.deep_search_a.main__graph.states import MainInput from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config from onyx.context.search.models import SearchRequest from onyx.db.engine import get_session_context_manager