Default values of a number of strings and other things

This commit is contained in:
joachim-danswer
2025-01-30 18:45:21 -08:00
committed by Evan Lohn
parent 385b344a43
commit 23ae4547ca
7 changed files with 55 additions and 43 deletions

View File

@ -244,23 +244,28 @@ def generate_initial_answer(
agent_base_end_time = datetime.now() agent_base_end_time = datetime.now()
if agent_base_end_time and state.agent_start_time:
duration__s = (agent_base_end_time - state.agent_start_time).total_seconds()
else:
duration__s = None
agent_base_metrics = AgentBaseMetrics( agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs), num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.original_question_retrieval_stats.verified_count, num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores, verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get( num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents", None "num_verified_documents"
), ),
verified_avg_score_base=initial_agent_stats.sub_questions.get( verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score", None "verified_avg_score"
), ),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get( base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", None "utilized_chunk_ratio"
), ),
support_boost_factor=initial_agent_stats.agent_effectiveness.get( support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio", None "support_ratio"
), ),
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(), duration__s=duration__s,
) )
logger.info( logger.info(

View File

@ -23,8 +23,8 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questi
from onyx.agents.agent_search.deep_search_a.main.nodes.decide_refinement_need import ( from onyx.agents.agent_search.deep_search_a.main.nodes.decide_refinement_need import (
decide_refinement_need, decide_refinement_need,
) )
from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import ( from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entities_terms import (
extract_entity_term, extract_entities_terms,
) )
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import ( from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
generate_refined_answer, generate_refined_answer,
@ -116,7 +116,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_node( graph.add_node(
node="extract_entity_term", node="extract_entity_term",
action=extract_entity_term, action=extract_entities_terms,
) )
graph.add_node( graph.add_node(
node="decide_refinement_need", node="decide_refinement_need",

View File

@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
def extract_entity_term( def extract_entities_terms(
state: MainState, config: RunnableConfig state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate: ) -> EntityTermExtractionUpdate:
now_start = datetime.now() now_start = datetime.now()
@ -68,7 +68,11 @@ def extract_entity_term(
) )
cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content)) cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
parsed_response = json.loads(cleaned_response) try:
parsed_response = json.loads(cleaned_response)
except json.JSONDecodeError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
parsed_response = {}
entities = [] entities = []
relationships = [] relationships = []
@ -76,37 +80,40 @@ def extract_entity_term(
for entity in parsed_response.get("retrieved_entities_relationships", {}).get( for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
"entities", {} "entities", {}
): ):
entity_name = entity.get("entity_name", "") entity_name = entity.get("entity_name")
entity_type = entity.get("entity_type", "") entity_type = entity.get("entity_type")
entities.append(Entity(entity_name=entity_name, entity_type=entity_type)) if entity_name and entity_type:
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get( for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
"relationships", {} "relationships", {}
): ):
relationship_name = relationship.get("relationship_name", "") relationship_name = relationship.get("relationship_name")
relationship_type = relationship.get("relationship_type", "") relationship_type = relationship.get("relationship_type")
relationship_entities = relationship.get("relationship_entities", []) relationship_entities = relationship.get("relationship_entities")
relationships.append( if relationship_name and relationship_type and relationship_entities:
Relationship( relationships.append(
relationship_name=relationship_name, Relationship(
relationship_type=relationship_type, relationship_name=relationship_name,
relationship_entities=relationship_entities, relationship_type=relationship_type,
relationship_entities=relationship_entities,
)
) )
)
for term in parsed_response.get("retrieved_entities_relationships", {}).get( for term in parsed_response.get("retrieved_entities_relationships", {}).get(
"terms", {} "terms", {}
): ):
term_name = term.get("term_name", "") term_name = term.get("term_name")
term_type = term.get("term_type", "") term_type = term.get("term_type")
term_similar_to = term.get("term_similar_to", []) term_similar_to = term.get("term_similar_to")
terms.append( if term_name and term_type and term_similar_to:
Term( terms.append(
term_name=term_name, Term(
term_type=term_type, term_name=term_name,
term_similar_to=term_similar_to, term_type=term_type,
term_similar_to=term_similar_to,
)
) )
)
now_end = datetime.now() now_end = datetime.now()

View File

@ -28,7 +28,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
agent_end_time = agent_refined_end_time or agent_base_end_time agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None agent_base_duration = None
if agent_base_end_time: if agent_base_end_time and agent_start_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds() agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None agent_refined_duration = None
@ -38,7 +38,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
).total_seconds() ).total_seconds()
agent_full_duration = None agent_full_duration = None
if agent_end_time: if agent_end_time and agent_start_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds() agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base" agent_type = "refined" if agent_refined_duration else "base"

View File

@ -56,14 +56,14 @@ class RefinedAgentEndStats(BaseModel):
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats): class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
agent_start_time: datetime = datetime.now() agent_start_time: datetime | None = None
previous_history: str = "" previous_history: str | None = None
initial_decomp_questions: list[str] = [] initial_decomp_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate): class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = [] exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str = "" previous_history_summary: str | None = None
class AnswerComparison(LoggerUpdate): class AnswerComparison(LoggerUpdate):
@ -71,24 +71,24 @@ class AnswerComparison(LoggerUpdate):
class RoutingDecision(LoggerUpdate): class RoutingDecision(LoggerUpdate):
routing_decision: str = "" routing_decision: str | None = None
# Not used in current graph # Not used in current graph
class InitialAnswerBASEUpdate(BaseModel): class InitialAnswerBASEUpdate(BaseModel):
initial_base_answer: str = "" initial_base_answer: str
class InitialAnswerUpdate(LoggerUpdate): class InitialAnswerUpdate(LoggerUpdate):
initial_answer: str = "" initial_answer: str
initial_agent_stats: InitialAgentResultStats | None = None initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = [] generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None agent_base_end_time: datetime
agent_base_metrics: AgentBaseMetrics | None = None agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate): class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
refined_answer: str = "" refined_answer: str
refined_agent_stats: RefinedAgentStats | None = None refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False refined_answer_quality: bool = False

View File

@ -73,7 +73,7 @@ def calculate_sub_question_retrieval_stats(
raw_chunk_stats_counts["verified_count"] raw_chunk_stats_counts["verified_count"]
) )
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None) rejected_scores = raw_chunk_stats_scores.get("rejected_scores")
if rejected_scores is not None: if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float( rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"] raw_chunk_stats_counts["rejected_count"]

View File

@ -953,7 +953,7 @@ def log_agent_metrics(
user_id: UUID | None, user_id: UUID | None,
persona_id: int | None, # Can be none if temporary persona is used persona_id: int | None, # Can be none if temporary persona is used
agent_type: str, agent_type: str,
start_time: datetime, start_time: datetime | None,
agent_metrics: CombinedAgentMetrics, agent_metrics: CombinedAgentMetrics,
) -> AgentSearchMetrics: ) -> AgentSearchMetrics:
agent_timings = agent_metrics.timings agent_timings = agent_metrics.timings