Commit bd3b1943c4 ("WIP PR comments")
Parent: 1dbf561db0
Repository: https://github.com/danswer-ai/danswer.git
@@ -12,7 +12,6 @@ from onyx.agents.agent_search.deep_search_a.initial__individual_sub_answer__subg
     QACheckUpdate,
 )
 from onyx.agents.agent_search.models import AgentSearchConfig
-from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_NO
 from onyx.agents.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.prompts import UNKNOWN_ANSWER
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
@@ -25,7 +24,7 @@ def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckU
     if state.answer == UNKNOWN_ANSWER:
         now_end = datetime.now()
         return QACheckUpdate(
-            answer_quality=SUB_CHECK_NO,
+            answer_quality=False,
             log_messages=[
                 f"{now_start} -- Answer check SQ-{level}-{question_num} - unknown answer, Time taken: {now_end - now_start}"
             ],
@@ -47,11 +46,12 @@ def answer_check(state: AnswerQuestionState, config: RunnableConfig) -> QACheckU
         )
     )

-    quality_str = merge_message_runs(response, chunk_separator="")[0].content
+    quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
+    answer_quality = "yes" in quality_str.lower()

     now_end = datetime.now()
     return QACheckUpdate(
-        answer_quality=quality_str,
+        answer_quality=answer_quality,
         log_messages=[
             f"""{now_start} -- Answer check SQ-{level}-{question_num} - Answer quality: {quality_str},
             Time taken: {now_end - now_start}"""
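A note on this hunk: the node now collapses the verifier LLM's free-text reply into a boolean via `"yes" in quality_str.lower()` instead of storing the raw string. A minimal standalone sketch of that reduction (the function name is ours and the replies are fabricated, no LLM call involved):

def parse_answer_quality(quality_str: str) -> bool:
    # The verifier prompt asks for a yes/no verdict; any case-insensitive
    # "yes" in the reply counts as a pass.
    return "yes" in quality_str.lower()

assert parse_answer_quality("Yes, the answer addresses the question.")
assert not parse_answer_quality("No - it does not.")

One caveat: any reply containing the substring "yes" anywhere passes, so this leans on the prompt constraining the model to a bare yes/no verdict.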
@@ -15,9 +15,7 @@ def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
         QuestionAnswerResults(
             question=state.question,
             question_id=state.question_id,
-            quality=state.answer_quality
-            if hasattr(state, "answer_quality")
-            else "No",
+            verified_high_quality=state.answer_quality,
             answer=state.answer,
             expanded_retrieval_results=state.expanded_retrieval_results,
             documents=state.documents,
@@ -17,7 +17,7 @@ from onyx.context.search.models import InferenceSection

 ## Update States
 class QACheckUpdate(BaseModel):
-    answer_quality: str = ""
+    answer_quality: bool = False
     log_messages: list[str] = []

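Aside on the `log_messages: list[str] = []` default kept in both versions: unlike plain Python classes, Pydantic copies mutable defaults per instance, so the shared-list pitfall does not apply here. A quick self-contained check (assuming standard Pydantic semantics, which both the old and new field rely on):

from pydantic import BaseModel

class QACheckUpdate(BaseModel):
    answer_quality: bool = False
    log_messages: list[str] = []

a, b = QACheckUpdate(), QACheckUpdate()
a.log_messages.append("checked")
assert b.log_messages == []  # each instance got its own copy of the default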
@@ -45,7 +45,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
-from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
@@ -136,9 +135,8 @@ def generate_initial_answer(

     for decomp_answer_result in decomp_answer_results:
         decomp_questions.append(decomp_answer_result.question)
-        _, question_nr = parse_question_id(decomp_answer_result.question_id)
         if (
-            decomp_answer_result.quality.lower().startswith("yes")
+            decomp_answer_result.verified_high_quality
             and len(decomp_answer_result.answer) > 0
             and decomp_answer_result.answer != UNKNOWN_ANSWER
         ):
@@ -151,15 +149,12 @@ def generate_initial_answer(
             )
             sub_question_nr += 1

-    if len(good_qa_list) > 0:
-        sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
-    else:
-        sub_question_answer_str = ""
-
     # Determine which base prompt to use given the sub-question information
     if len(good_qa_list) > 0:
+        sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
         base_prompt = INITIAL_RAG_PROMPT
     else:
+        sub_question_answer_str = ""
         base_prompt = INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS

     model = agent_a_config.fast_llm
@@ -182,7 +177,7 @@ def generate_initial_answer(
             answered_sub_questions=remove_document_citations(
                 sub_question_answer_str
             ),
-            relevant_docs=format_docs(relevant_docs),
+            relevant_docs=doc_context,
             persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
             history=prompt_enrichment_components.history,
             date_prompt=prompt_enrichment_components.date_str,
@@ -137,7 +137,7 @@ def generate_refined_answer(

         decomp_questions.append(decomp_answer_result.question)
         if (
-            decomp_answer_result.quality.lower().startswith("yes")
+            decomp_answer_result.verified_high_quality
             and len(decomp_answer_result.answer) > 0
             and decomp_answer_result.answer != UNKNOWN_ANSWER
         ):
@@ -67,11 +67,11 @@ def refined_sub_question_creation(
     initial_question_answers = state.decomp_answer_results

     addressed_question_list = [
-        x.question for x in initial_question_answers if "yes" in x.quality.lower()
+        x.question for x in initial_question_answers if x.verified_high_quality
     ]

     failed_question_list = [
-        x.question for x in initial_question_answers if "no" in x.quality.lower()
+        x.question for x in initial_question_answers if not x.verified_high_quality
     ]

     msg = [
@@ -19,9 +19,8 @@ def parallel_retrieval_edge(
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     question = state.question if state.question else agent_a_config.search_request.query

-    query_expansions = (
-        state.expanded_queries if state.expanded_queries else [] + [question]
-    )
+    query_expansions = state.expanded_queries + [question]

     return [
         Send(
             "doc_retrieval",
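The deleted expression was subtly buggy: `+` binds tighter than the conditional expression, so it parsed as `state.expanded_queries if state.expanded_queries else ([] + [question])`, and the original question was only included when there were no expansions at all. A small demonstration of both forms (variable values fabricated):

expanded_queries = ["rewrite 1", "rewrite 2"]
question = "original question"

# Old form: question dropped whenever any expansions exist.
old = expanded_queries if expanded_queries else [] + [question]
# New form: the original question is always retrieved against as well.
new = expanded_queries + [question]

assert old == ["rewrite 1", "rewrite 2"]
assert new == ["rewrite 1", "rewrite 2", "original question"]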
@@ -33,11 +33,8 @@ def expand_queries(
     # Instead, we use the original question from the search request.
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     now_start = datetime.now()
-    question = (
-        state.question
-        if hasattr(state, "question")
-        else agent_a_config.search_request.query
-    )
+    question = state.question

     llm = agent_a_config.fast_llm
     chat_session_id = agent_a_config.chat_session_id
     sub_question_id = state.sub_question_id
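This hunk and the `verification_kickoff` one further down drop the same `hasattr(state, "question")` fallback. That is safe assuming `question` is a declared field on the Pydantic state model: declared fields always exist as attributes (with their default if unset), so the else-branch was dead code. A tiny check with a hypothetical stand-in class:

from pydantic import BaseModel

class _StateStandIn(BaseModel):  # hypothetical stand-in for the real state class
    question: str = ""

# A declared field is always present, so hasattr was always True here.
assert hasattr(_StateStandIn(), "question")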
@@ -1,4 +1,3 @@
-from typing import cast
 from typing import Literal

 from langchain_core.runnables.config import RunnableConfig
@@ -11,7 +10,6 @@ from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.s
 from onyx.agents.agent_search.deep_search_a.util__expanded_retrieval__subgraph.states import (
     ExpandedRetrievalState,
 )
-from onyx.agents.agent_search.models import AgentSearchConfig


 def verification_kickoff(
@@ -19,12 +17,8 @@ def verification_kickoff(
     config: RunnableConfig,
 ) -> Command[Literal["doc_verification"]]:
     documents = state.retrieved_documents
-    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
-    verification_question = (
-        state.question
-        if hasattr(state, "question")
-        else agent_a_config.search_request.query
-    )
+    verification_question = state.question

     sub_question_id = state.sub_question_id
     return Command(
         update={},
@@ -51,23 +51,15 @@ def calculate_sub_question_retrieval_stats(
     raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
     raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
     for doc_chunk_id, chunk_data in chunk_scores.items():
-        if doc_chunk_id in verified_doc_chunk_ids:
-            raw_chunk_stats_counts["verified_count"] += 1
-            valid_chunk_scores = [
-                score for score in chunk_data["score"] if score is not None
-            ]
-            raw_chunk_stats_scores["verified_scores"] += float(
-                np.mean(valid_chunk_scores)
-            )
-        else:
-            raw_chunk_stats_counts["rejected_count"] += 1
-            valid_chunk_scores = [
-                score for score in chunk_data["score"] if score is not None
-            ]
-            raw_chunk_stats_scores["rejected_scores"] += float(
-                np.mean(valid_chunk_scores)
-            )
+        key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected"
+        raw_chunk_stats_counts[f"{key}_count"] += 1
+
+        valid_chunk_scores = [
+            score for score in chunk_data["score"] if score is not None
+        ]
+        raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores))
+
+        if key == "rejected":
             dismissed_doc_chunk_ids.append(doc_chunk_id)

     if raw_chunk_stats_counts["verified_count"] == 0:
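For context, the refactor folds the duplicated verified/rejected branches into one accumulation keyed on "verified"/"rejected". The sketch below mirrors the new loop with fabricated chunk data (the dict shapes are assumptions inferred from this hunk):

from collections import defaultdict

import numpy as np

chunk_scores = {
    "doc1__0": {"score": [0.9, None, 0.7]},
    "doc2__3": {"score": [0.2]},
}
verified_doc_chunk_ids = {"doc1__0"}
dismissed_doc_chunk_ids: list[str] = []

raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
    # One code path for both outcomes; the key picks the stat bucket.
    key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected"
    raw_chunk_stats_counts[f"{key}_count"] += 1

    valid_chunk_scores = [s for s in chunk_data["score"] if s is not None]
    raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores))

    if key == "rejected":
        dismissed_doc_chunk_ids.append(doc_chunk_id)

assert raw_chunk_stats_counts == {"verified_count": 1, "rejected_count": 1}
assert dismissed_doc_chunk_ids == ["doc2__3"]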
@@ -30,7 +30,7 @@ class ExpandedRetrievalInput(SubgraphCoreState):


 class QueryExpansionUpdate(BaseModel):
-    expanded_queries: list[str] = ["aaa", "bbb"]
+    expanded_queries: list[str] = []
     log_messages: list[str] = []

@@ -19,7 +19,7 @@ class AgentSearchConfig(BaseModel):
     Configuration for the Agent Search feature.
     """

-    # The search request that was used to generate the Pro Search
+    # The search request that was used to generate the Agent Search
     search_request: SearchRequest

     primary_llm: LLM
@@ -45,7 +45,7 @@ class AgentSearchConfig(BaseModel):
     # The message ID of the user message that triggered the Pro Search
     message_id: int | None = None

-    # Whether to persistence data for Agentic Search (turned off for testing)
+    # Whether to persist data for Agentic Search
     use_agentic_persistence: bool = True

     # The database session for Agentic Search
@@ -89,6 +89,72 @@ class AgentSearchConfig(BaseModel):
         arbitrary_types_allowed = True


+class GraphInputs(BaseModel):
+    """Input data required for the graph execution"""
+
+    search_request: SearchRequest
+    prompt_builder: AnswerPromptBuilder
+    files: list[InMemoryChatFile] | None = None
+    structured_response_format: dict | None = None
+
+
+class GraphTooling(BaseModel):
+    """Tools and LLMs available to the graph"""
+
+    primary_llm: LLM
+    fast_llm: LLM
+    search_tool: SearchTool | None = None
+    tools: list[Tool] | None = None
+    force_use_tool: ForceUseTool
+    using_tool_calling_llm: bool = False
+
+
+class GraphPersistence(BaseModel):
+    """Configuration for data persistence"""
+
+    chat_session_id: UUID | None = None
+    message_id: int | None = None
+    use_agentic_persistence: bool = True
+    db_session: Session | None = None
+
+    @model_validator(mode="after")
+    def validate_db_session(self) -> "GraphPersistence":
+        if self.use_agentic_persistence and self.db_session is None:
+            raise ValueError(
+                "db_session must be provided for pro search when using persistence"
+            )
+        return self
+
+
+class SearchBehaviorConfig(BaseModel):
+    """Configuration controlling search behavior"""
+
+    use_agentic_search: bool = False
+    perform_initial_search_decomposition: bool = True
+    allow_refinement: bool = True
+    skip_gen_ai_answer_generation: bool = False
+
+
+class GraphConfig(BaseModel):
+    """
+    Main configuration class that combines all config components for Langgraph execution
+    """
+
+    inputs: GraphInputs
+    tooling: GraphTooling
+    persistence: GraphPersistence
+    behavior: SearchBehaviorConfig
+
+    @model_validator(mode="after")
+    def validate_search_tool(self) -> "GraphConfig":
+        if self.behavior.use_agentic_search and self.tooling.search_tool is None:
+            raise ValueError("search_tool must be provided for agentic search")
+        return self
+
+    class Config:
+        arbitrary_types_allowed = True
+
+
 class AgentDocumentCitations(BaseModel):
     document_id: str
     document_title: str
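A hedged usage sketch of the new validator behavior, using a simplified stand-in for `GraphPersistence` (here `db_session` is typed as a plain object and `chat_session_id` as a string, since constructing real SQLAlchemy sessions is out of scope; the field names mirror the class added above):

from pydantic import BaseModel, model_validator

class GraphPersistence(BaseModel):  # simplified stand-in, not the real class
    chat_session_id: str | None = None
    message_id: int | None = None
    use_agentic_persistence: bool = True
    db_session: object | None = None

    @model_validator(mode="after")
    def validate_db_session(self) -> "GraphPersistence":
        if self.use_agentic_persistence and self.db_session is None:
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
        return self

# Persistence on without a session fails fast at construction time...
try:
    GraphPersistence()
except ValueError as e:
    print(e)

# ...while explicitly opting out of persistence is fine.
GraphPersistence(use_agentic_persistence=False)

Moving the db_session/search_tool requirements into `model_validator` hooks means misconfiguration surfaces when the config object is built, not deep inside graph execution.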
@@ -103,7 +103,7 @@ class QuestionAnswerResults(BaseModel):
     question: str
     question_id: str
     answer: str
-    quality: str
+    verified_high_quality: bool
     expanded_retrieval_results: list[QueryResult]
     documents: list[InferenceSection]
     context_documents: list[InferenceSection]
@@ -40,6 +40,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from onyx.configs.constants import DEFAULT_PERSONA_ID
+from onyx.configs.constants import DISPATCH_SEP_CHAR
 from onyx.context.search.enums import LLMEvaluationType
 from onyx.context.search.models import InferenceSection
 from onyx.context.search.models import RetrievalDetails
@@ -56,6 +57,8 @@ from onyx.tools.tool_implementations.search.search_tool import (
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
 from onyx.tools.tool_implementations.search.search_tool import SearchTool

+BaseMessage_Content = str | list[str | dict[str, Any]]
+

 def normalize_whitespace(text: str) -> str:
     """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
@@ -289,14 +292,14 @@ def _dispatch_nonempty(


 def dispatch_separated(
-    token_itr: Iterator[BaseMessage],
+    tokens: Iterator[BaseMessage],
     dispatch_event: Callable[[str, int], None],
-    sep: str = "\n",
-) -> list[str | list[str | dict[str, Any]]]:
+    sep: str = DISPATCH_SEP_CHAR,
+) -> list[BaseMessage_Content]:
     num = 1
-    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
-    for message in token_itr:
-        content = cast(str, message.content)
+    streamed_tokens: list[BaseMessage_Content] = []
+    for token in tokens:
+        content = cast(str, token.content)
         if sep in content:
             sub_question_parts = content.split(sep)
             _dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
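The hunk only shows the head of the rewritten loop, so the following is a plausible reconstruction of the splitting behavior rather than the exact onyx implementation: tokens stream in, each occurrence of the separator starts a new part, and `dispatch_event` is told which part number a fragment belongs to.

from collections.abc import Callable, Iterator

DISPATCH_SEP_CHAR = "\n"

def dispatch_separated_sketch(
    tokens: Iterator[str],
    dispatch_event: Callable[[str, int], None],
    sep: str = DISPATCH_SEP_CHAR,
) -> list[str]:
    num = 1
    parts: list[str] = [""]
    for content in tokens:
        # A single streamed token may contain zero or more separators.
        while sep in content:
            head, content = content.split(sep, 1)
            if head:
                dispatch_event(head, num)
                parts[-1] += head
            parts.append("")  # separator closes the current part
            num += 1
        if content:
            dispatch_event(content, num)
            parts[-1] += content
    return parts

events: list[tuple[str, int]] = []
out = dispatch_separated_sketch(
    iter(["sub q 1\nsub", " q 2"]), lambda s, n: events.append((s, n))
)
assert out == ["sub q 1", "sub q 2"]
assert events == [("sub q 1", 1), ("sub", 2), (" q 2", 2)]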
@@ -42,6 +42,7 @@ DEFAULT_CC_PAIR_ID = 1
 BASIC_KEY = (-1, -1)
 AGENT_SEARCH_INITIAL_KEY = (0, 0)
 CANCEL_CHECK_INTERVAL = 20
+DISPATCH_SEP_CHAR = "\n"
 # Postgres connection constants for application_name
 POSTGRES_WEB_APP_NAME = "web"
 POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -5,8 +5,10 @@ import os

 import yaml

-from onyx.agents.agent_search.deep_search_a.main.graph_builder import main_graph_builder
-from onyx.agents.agent_search.deep_search_a.main.states import MainInput
+from onyx.agents.agent_search.deep_search_a.main__graph.graph_builder import (
+    main_graph_builder,
+)
+from onyx.agents.agent_search.deep_search_a.main__graph.states import MainInput
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.context.search.models import SearchRequest
 from onyx.db.engine import get_session_context_manager