Set default values for the number of sub-questions and related agent-search parameters; harden entity/term extraction against malformed LLM JSON

This commit is contained in:
joachim-danswer 2025-01-30 18:45:21 -08:00 committed by Evan Lohn
parent 385b344a43
commit 23ae4547ca
7 changed files with 55 additions and 43 deletions

View File

@ -244,23 +244,28 @@ def generate_initial_answer(
agent_base_end_time = datetime.now()
if agent_base_end_time and state.agent_start_time:
duration__s = (agent_base_end_time - state.agent_start_time).total_seconds()
else:
duration__s = None
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents", None
"num_verified_documents"
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score", None
"verified_avg_score"
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", None
"utilized_chunk_ratio"
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio", None
"support_ratio"
),
duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
duration__s=duration__s,
)
logger.info(

View File

@ -23,8 +23,8 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questi
from onyx.agents.agent_search.deep_search_a.main.nodes.decide_refinement_need import (
decide_refinement_need,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import (
extract_entity_term,
from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entities_terms import (
extract_entities_terms,
)
from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
generate_refined_answer,
@ -116,7 +116,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_node(
node="extract_entity_term",
action=extract_entity_term,
action=extract_entities_terms,
)
graph.add_node(
node="decide_refinement_need",

View File

@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
def extract_entity_term(
def extract_entities_terms(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
now_start = datetime.now()
@ -68,7 +68,11 @@ def extract_entity_term(
)
cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
parsed_response = json.loads(cleaned_response)
try:
parsed_response = json.loads(cleaned_response)
except json.JSONDecodeError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
parsed_response = {}
entities = []
relationships = []
@ -76,37 +80,40 @@ def extract_entity_term(
for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
"entities", {}
):
entity_name = entity.get("entity_name", "")
entity_type = entity.get("entity_type", "")
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
entity_name = entity.get("entity_name")
entity_type = entity.get("entity_type")
if entity_name and entity_type:
entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
"relationships", {}
):
relationship_name = relationship.get("relationship_name", "")
relationship_type = relationship.get("relationship_type", "")
relationship_entities = relationship.get("relationship_entities", [])
relationships.append(
Relationship(
relationship_name=relationship_name,
relationship_type=relationship_type,
relationship_entities=relationship_entities,
relationship_name = relationship.get("relationship_name")
relationship_type = relationship.get("relationship_type")
relationship_entities = relationship.get("relationship_entities")
if relationship_name and relationship_type and relationship_entities:
relationships.append(
Relationship(
relationship_name=relationship_name,
relationship_type=relationship_type,
relationship_entities=relationship_entities,
)
)
)
for term in parsed_response.get("retrieved_entities_relationships", {}).get(
"terms", {}
):
term_name = term.get("term_name", "")
term_type = term.get("term_type", "")
term_similar_to = term.get("term_similar_to", [])
terms.append(
Term(
term_name=term_name,
term_type=term_type,
term_similar_to=term_similar_to,
term_name = term.get("term_name")
term_type = term.get("term_type")
term_similar_to = term.get("term_similar_to")
if term_name and term_type and term_similar_to:
terms.append(
Term(
term_name=term_name,
term_type=term_type,
term_similar_to=term_similar_to,
)
)
)
now_end = datetime.now()

View File

@ -28,7 +28,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time:
if agent_base_end_time and agent_start_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
@ -38,7 +38,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
).total_seconds()
agent_full_duration = None
if agent_end_time:
if agent_end_time and agent_start_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"

View File

@ -56,14 +56,14 @@ class RefinedAgentEndStats(BaseModel):
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
agent_start_time: datetime = datetime.now()
previous_history: str = ""
agent_start_time: datetime | None = None
previous_history: str | None = None
initial_decomp_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str = ""
previous_history_summary: str | None = None
class AnswerComparison(LoggerUpdate):
@ -71,24 +71,24 @@ class AnswerComparison(LoggerUpdate):
class RoutingDecision(LoggerUpdate):
routing_decision: str = ""
routing_decision: str | None = None
# Not used in current graph
class InitialAnswerBASEUpdate(BaseModel):
initial_base_answer: str = ""
initial_base_answer: str
class InitialAnswerUpdate(LoggerUpdate):
initial_answer: str = ""
initial_answer: str
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_end_time: datetime
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
refined_answer: str = ""
refined_answer: str
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False

View File

@ -73,7 +73,7 @@ def calculate_sub_question_retrieval_stats(
raw_chunk_stats_counts["verified_count"]
)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores")
if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"]

View File

@ -953,7 +953,7 @@ def log_agent_metrics(
user_id: UUID | None,
persona_id: int | None, # Can be none if temporary persona is used
agent_type: str,
start_time: datetime,
start_time: datetime | None,
agent_metrics: CombinedAgentMetrics,
) -> AgentSearchMetrics:
agent_timings = agent_metrics.timings