WIP PR comments

This commit is contained in:
Evan Lohn 2025-01-30 09:46:53 -08:00
parent 1dbf561db0
commit bd3b1943c4
16 changed files with 111 additions and 64 deletions

View File

@ -12,7 +12,6 @@ from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subg
QACheckUpdate,
)
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
@ -25,7 +24,7 @@ def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckU
if state.answer == UNKNOWN_ANSWER:
now_end = datetime.now()
return QACheckUpdate(
answer_quality=SUB_CHECK_NO,
answer_quality=False,
log_messages=[
f"{now_start} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
],
@ -47,11 +46,12 @@ def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckU
)
)
quality_str = merge_message_runs(response, chunk_separator="")[0].content
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
now_end = datetime.now()
return QACheckUpdate(
answer_quality=quality_str,
answer_quality=answer_quality,
log_messages=[
f"""{now_start} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
Time taken: {now_end - now_start}"""

View File

@ -15,9 +15,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
QuestionAnswerResults(
question=state.question,
question_id=state.question_id,
quality=state.answer_quality
if hasattr(state, "answer_quality")
else "No",
verified_high_quality=state.answer_quality,
answer=state.answer,
expanded_retrieval_results=state.expanded_retrieval_results,
documents=state.documents,

View File

@ -17,7 +17,7 @@ from onyx.context.search.models import InferenceSection
## Update States
class QACheckUpdate(BaseModel):
answer_quality: str = ""
answer_quality: bool = False
log_messages: list[str] = []

View File

@ -45,7 +45,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@ -136,9 +135,8 @@ def generate_initial_answer(
for decomp_answer_result in decomp_answer_results:
decomp_questions.append(decomp_answer_result.question)
_, question_nr = parse_question_id(decomp_answer_result.question_id)
if (
decomp_answer_result.quality.lower().startswith("yes")
decomp_answer_result.verified_high_quality
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):
@ -151,15 +149,12 @@ def generate_initial_answer(
)
sub_question_nr += 1
if len(good_qa_list) > 0:
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
else:
sub_question_answer_str = ""
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
base_prompt = INITIAL_RAG_PROMPT
else:
sub_question_answer_str = ""
base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_a_config.fast_llm
@ -182,7 +177,7 @@ def generate_initial_answer(
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=format_docs(relevant_docs),
relevant_docs=doc_context,
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
history=prompt_enrichment_components.history,
date_prompt=prompt_enrichment_components.date_str,

View File

@ -137,7 +137,7 @@ def generate_refined_answer(
decomp_questions.append(decomp_answer_result.question)
if (
decomp_answer_result.quality.lower().startswith("yes")
decomp_answer_result.verified_high_quality
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != UNKNOWN_ANSWER
):

View File

@ -67,11 +67,11 @@ def refined_sub_question_creation(
initial_question_answers = state.decomp_answer_results
addressed_question_list = [
x.question for x in initial_question_answers if "yes" in x.quality.lower()
x.question for x in initial_question_answers if x.verified_high_quality
]
failed_question_list = [
x.question for x in initial_question_answers if "no" in x.quality.lower()
x.question for x in initial_question_answers if not x.verified_high_quality
]
msg = [

View File

@ -19,9 +19,8 @@ def parallel_retrieval_edge(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = state.question if state.question else agent_a_config.search_request.query
query_expansions = (
state.expanded_queries if state.expanded_queries else [] + [question]
)
query_expansions = state.expanded_queries + [question]
return [
Send(
"doc_retrieval",

View File

@ -33,11 +33,8 @@ def expand_queries(
# Instead, we use the original question from the search request.
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
now_start = datetime.now()
question = (
state.question
if hasattr(state, "question")
else agent_a_config.search_request.query
)
question = state.question
llm = agent_a_config.fast_llm
chat_session_id = agent_a_config.chat_session_id
sub_question_id = state.sub_question_id

View File

@ -1,4 +1,3 @@
from typing import cast
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
@ -11,7 +10,6 @@ from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.s
from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import AgentSearchConfig
def verification_kickoff(
@ -19,12 +17,8 @@ def verification_kickoff(
config: RunnableConfig,
) -> Command[Literal["doc_verification"]]:
documents = state.retrieved_documents
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
verification_question = (
state.question
if hasattr(state, "question")
else agent_a_config.search_request.query
)
verification_question = state.question
sub_question_id = state.sub_question_id
return Command(
update={},

View File

@ -51,23 +51,15 @@ def calculate_sub_question_retrieval_stats(
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
if doc_chunk_id in verified_doc_chunk_ids:
raw_chunk_stats_counts["verified_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected"
raw_chunk_stats_counts[f"{key}_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["verified_scores"] += float(
np.mean(valid_chunk_scores)
)
else:
raw_chunk_stats_counts["rejected_count"] += 1
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
raw_chunk_stats_scores["rejected_scores"] += float(
np.mean(valid_chunk_scores)
)
raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores))
if key == "rejected":
dismissed_doc_chunk_ids.append(doc_chunk_id)
if raw_chunk_stats_counts["verified_count"] == 0:

View File

@ -30,7 +30,7 @@ class ExpandedRetrievalInput(SubgraphCoreState):
class QueryExpansionUpdate(BaseModel):
expanded_queries: list[str] = ["aaa", "bbb"]
expanded_queries: list[str] = []
log_messages: list[str] = []

View File

@ -19,7 +19,7 @@ class AgentSearchConfig(BaseModel):
Configuration for the Agent Search feature.
"""
# The search request that was used to generate the Pro Search
# The search request that was used to generate the Agent Search
search_request: SearchRequest
primary_llm: LLM
@ -45,7 +45,7 @@ class AgentSearchConfig(BaseModel):
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# Whether to persistence data for Agentic Search (turned off for testing)
# Whether to persist data for Agentic Search
use_agentic_persistence: bool = True
# The database session for Agentic Search
@ -89,6 +89,72 @@ class AgentSearchConfig(BaseModel):
arbitrary_types_allowed = True
class GraphInputs(BaseModel):
"""Input data required for the graph execution"""
search_request: SearchRequest
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
class GraphTooling(BaseModel):
"""Tools and LLMs available to the graph"""
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool | None = None
tools: list[Tool] | None = None
force_use_tool: ForceUseTool
using_tool_calling_llm: bool = False
class GraphPersistence(BaseModel):
"""Configuration for data persistence"""
chat_session_id: UUID | None = None
message_id: int | None = None
use_agentic_persistence: bool = True
db_session: Session | None = None
@model_validator(mode="after")
def validate_db_session(self) -> "GraphPersistence":
if self.use_agentic_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
return self
class SearchBehaviorConfig(BaseModel):
"""Configuration controlling search behavior"""
use_agentic_search: bool = False
perform_initial_search_decomposition: bool = True
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
class GraphConfig(BaseModel):
"""
Main configuration class that combines all config components for Langgraph execution
"""
inputs: GraphInputs
tooling: GraphTooling
persistence: GraphPersistence
behavior: SearchBehaviorConfig
@model_validator(mode="after")
def validate_search_tool(self) -> "GraphConfig":
if self.behavior.use_agentic_search and self.tooling.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True
class AgentDocumentCitations(BaseModel):
document_id: str
document_title: str

View File

@ -103,7 +103,7 @@ class QuestionAnswerResults(BaseModel):
question: str
question_id: str
answer: str
quality: str
verified_high_quality: bool
expanded_retrieval_results: list[QueryResult]
documents: list[InferenceSection]
context_documents: list[InferenceSection]

View File

@ -40,6 +40,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DISPATCH_SEP_CHAR
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
@ -56,6 +57,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
BaseMessage_Content = str | list[str | dict[str, Any]]
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
@ -289,14 +292,14 @@ def _dispatch_nonempty(
def dispatch_separated(
token_itr: Iterator[BaseMessage],
tokens: Iterator[BaseMessage],
dispatch_event: Callable[[str, int], None],
sep: str = "\n",
) -> list[str | list[str | dict[str, Any]]]:
sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
num = 1
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
for message in token_itr:
content = cast(str, message.content)
streamed_tokens: list[BaseMessage_Content] = []
for token in tokens:
content = cast(str, token.content)
if sep in content:
sub_question_parts = content.split(sep)
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)

View File

@ -42,6 +42,7 @@ DEFAULT_CC_PAIR_ID = 1
BASIC_KEY = (-1, -1)
AGENT_SEARCH_INITIAL_KEY = (0, 0)
CANCEL_CHECK_INTERVAL = 20
DISPATCH_SEP_CHAR = "\n"
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"

View File

@ -5,8 +5,10 @@ import os
import yaml
from onyx.agents.agent_search.deep_search_a.main.graph_builder import main_graph_builder
from onyx.agents.agent_search.deep_search_a.main.states import MainInput
from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import (
main_graph_builder,
)
from onyx.agents.agent_search.deep_search_a.main__graph.states import MainInput
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager