From 23ae4547cad943a48d0aa9677fdebb42dff764de Mon Sep 17 00:00:00 2001 From: joachim-danswer Date: Thu, 30 Jan 2025 18:45:21 -0800 Subject: [PATCH] default values of number of strings and other things --- .../nodes/generate_initial_answer.py | 15 ++++-- .../deep_search_a/main/graph_builder.py | 6 +-- ...tity_term.py => extract_entities_terms.py} | 53 +++++++++++-------- .../main/nodes/persist_agent_results.py | 4 +- .../agent_search/deep_search_a/main/states.py | 16 +++--- .../shared/expanded_retrieval/operations.py | 2 +- backend/onyx/db/chat.py | 2 +- 7 files changed, 55 insertions(+), 43 deletions(-) rename backend/onyx/agents/agent_search/deep_search_a/main/nodes/{extract_entity_term.py => extract_entities_terms.py} (74%) diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py index 324e8c183..e248ea1b6 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/initial/generate_initial_answer/nodes/generate_initial_answer.py @@ -244,23 +244,28 @@ def generate_initial_answer( agent_base_end_time = datetime.now() + if agent_base_end_time and state.agent_start_time: + duration__s = (agent_base_end_time - state.agent_start_time).total_seconds() + else: + duration__s = None + agent_base_metrics = AgentBaseMetrics( num_verified_documents_total=len(relevant_docs), num_verified_documents_core=state.original_question_retrieval_stats.verified_count, verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores, num_verified_documents_base=initial_agent_stats.sub_questions.get( - "num_verified_documents", None + "num_verified_documents" ), verified_avg_score_base=initial_agent_stats.sub_questions.get( - "verified_avg_score", None + "verified_avg_score" ), base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get( - "utilized_chunk_ratio", None + "utilized_chunk_ratio" ), support_boost_factor=initial_agent_stats.agent_effectiveness.get( - "support_ratio", None + "support_ratio" ), - duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(), + duration__s=duration__s, ) logger.info( diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py b/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py index a121025bf..08a776075 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/graph_builder.py @@ -23,8 +23,8 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questi from onyx.agents.agent_search.deep_search_a.main.nodes.decide_refinement_need import ( decide_refinement_need, ) -from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import ( - extract_entity_term, +from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entities_terms import ( + extract_entities_terms, ) from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import ( generate_refined_answer, @@ -116,7 +116,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph: graph.add_node( node="extract_entity_term", - action=extract_entity_term, + action=extract_entities_terms, ) graph.add_node( node="decide_refinement_need", diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/extract_entity_term.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/extract_entities_terms.py similarity index 74% rename from backend/onyx/agents/agent_search/deep_search_a/main/nodes/extract_entity_term.py rename to backend/onyx/agents/agent_search/deep_search_a/main/nodes/extract_entities_terms.py index 9748677a4..b241e5fc3 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/extract_entity_term.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/extract_entities_terms.py @@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM from onyx.agents.agent_search.shared_graph_utils.utils import format_docs -def extract_entity_term( +def extract_entities_terms( state: MainState, config: RunnableConfig ) -> EntityTermExtractionUpdate: now_start = datetime.now() @@ -68,7 +68,11 @@ def extract_entity_term( ) cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content)) - parsed_response = json.loads(cleaned_response) + try: + parsed_response = json.loads(cleaned_response) + except json.JSONDecodeError: + logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction") + parsed_response = {} entities = [] relationships = [] @@ -76,37 +80,40 @@ def extract_entity_term( for entity in parsed_response.get("retrieved_entities_relationships", {}).get( "entities", {} ): - entity_name = entity.get("entity_name", "") - entity_type = entity.get("entity_type", "") - entities.append(Entity(entity_name=entity_name, entity_type=entity_type)) + entity_name = entity.get("entity_name") + entity_type = entity.get("entity_type") + if entity_name and entity_type: + entities.append(Entity(entity_name=entity_name, entity_type=entity_type)) for relationship in parsed_response.get("retrieved_entities_relationships", {}).get( "relationships", {} ): - relationship_name = relationship.get("relationship_name", "") - relationship_type = relationship.get("relationship_type", "") - relationship_entities = relationship.get("relationship_entities", []) - relationships.append( - Relationship( - relationship_name=relationship_name, - relationship_type=relationship_type, - relationship_entities=relationship_entities, + relationship_name = relationship.get("relationship_name") + relationship_type = relationship.get("relationship_type") + relationship_entities = relationship.get("relationship_entities") + if relationship_name and relationship_type and relationship_entities: + relationships.append( + Relationship( + relationship_name=relationship_name, + relationship_type=relationship_type, + relationship_entities=relationship_entities, + ) ) - ) for term in parsed_response.get("retrieved_entities_relationships", {}).get( "terms", {} ): - term_name = term.get("term_name", "") - term_type = term.get("term_type", "") - term_similar_to = term.get("term_similar_to", []) - terms.append( - Term( - term_name=term_name, - term_type=term_type, - term_similar_to=term_similar_to, + term_name = term.get("term_name") + term_type = term.get("term_type") + term_similar_to = term.get("term_similar_to") + if term_name and term_type and term_similar_to: + terms.append( + Term( + term_name=term_name, + term_type=term_type, + term_similar_to=term_similar_to, + ) ) - ) now_end = datetime.now() diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/persist_agent_results.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/persist_agent_results.py index 8bf815157..2f95c5609 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/persist_agent_results.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/persist_agent_results.py @@ -28,7 +28,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu agent_end_time = agent_refined_end_time or agent_base_end_time agent_base_duration = None - if agent_base_end_time: + if agent_base_end_time and agent_start_time: agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds() agent_refined_duration = None @@ -38,7 +38,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu ).total_seconds() agent_full_duration = None - if agent_end_time: + if agent_end_time and agent_start_time: agent_full_duration = (agent_end_time - agent_start_time).total_seconds() agent_type = "refined" if agent_refined_duration else "base" diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/states.py b/backend/onyx/agents/agent_search/deep_search_a/main/states.py index bb6e1283f..7729c5637 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/states.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/states.py @@ -56,14 +56,14 @@ class RefinedAgentEndStats(BaseModel): class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats): - agent_start_time: datetime = datetime.now() - previous_history: str = "" + agent_start_time: datetime | None = None + previous_history: str | None = None initial_decomp_questions: list[str] = [] class ExploratorySearchUpdate(LoggerUpdate): exploratory_search_results: list[InferenceSection] = [] - previous_history_summary: str = "" + previous_history_summary: str | None = None class AnswerComparison(LoggerUpdate): @@ -71,24 +71,24 @@ class AnswerComparison(LoggerUpdate): class RoutingDecision(LoggerUpdate): - routing_decision: str = "" + routing_decision: str | None = None # Not used in current graph class InitialAnswerBASEUpdate(BaseModel): - initial_base_answer: str = "" + initial_base_answer: str class InitialAnswerUpdate(LoggerUpdate): - initial_answer: str = "" + initial_answer: str initial_agent_stats: InitialAgentResultStats | None = None generated_sub_questions: list[str] = [] - agent_base_end_time: datetime | None = None + agent_base_end_time: datetime agent_base_metrics: AgentBaseMetrics | None = None class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate): - refined_answer: str = "" + refined_answer: str refined_agent_stats: RefinedAgentStats | None = None refined_answer_quality: bool = False diff --git a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py index a6ac9bfa5..79471ee6d 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py +++ b/backend/onyx/agents/agent_search/deep_search_a/shared/expanded_retrieval/operations.py @@ -73,7 +73,7 @@ def calculate_sub_question_retrieval_stats( raw_chunk_stats_counts["verified_count"] ) - rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None) + rejected_scores = raw_chunk_stats_scores.get("rejected_scores") if rejected_scores is not None: rejected_avg_scores = rejected_scores / float( raw_chunk_stats_counts["rejected_count"] diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index f586b294c..2bf1e0062 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -953,7 +953,7 @@ def log_agent_metrics( user_id: UUID | None, persona_id: int | None, # Can be none if temporary persona is used agent_type: str, - start_time: datetime, + start_time: datetime | None, agent_metrics: CombinedAgentMetrics, ) -> AgentSearchMetrics: agent_timings = agent_metrics.timings