Mirror of https://github.com/danswer-ai/danswer.git
default values of number of strings and other things

commit 23ae4547ca
parent 385b344a43
@@ -244,23 +244,28 @@ def generate_initial_answer(

     agent_base_end_time = datetime.now()

+    if agent_base_end_time and state.agent_start_time:
+        duration__s = (agent_base_end_time - state.agent_start_time).total_seconds()
+    else:
+        duration__s = None
+
     agent_base_metrics = AgentBaseMetrics(
         num_verified_documents_total=len(relevant_docs),
         num_verified_documents_core=state.original_question_retrieval_stats.verified_count,
         verified_avg_score_core=state.original_question_retrieval_stats.verified_avg_scores,
         num_verified_documents_base=initial_agent_stats.sub_questions.get(
-            "num_verified_documents", None
+            "num_verified_documents"
         ),
         verified_avg_score_base=initial_agent_stats.sub_questions.get(
-            "verified_avg_score", None
+            "verified_avg_score"
         ),
         base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
-            "utilized_chunk_ratio", None
+            "utilized_chunk_ratio"
         ),
         support_boost_factor=initial_agent_stats.agent_effectiveness.get(
-            "support_ratio", None
+            "support_ratio"
         ),
-        duration__s=(agent_base_end_time - state.agent_start_time).total_seconds(),
+        duration__s=duration__s,
     )

     logger.info(
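Two things change in this hunk: the explicit `, None` second argument to `dict.get` is dropped (it is already the default), and `duration__s` is only computed when both timestamps are present instead of subtracting unconditionally. A minimal standalone sketch of the same pattern (the dictionary contents and helper name below are illustrative, not taken from the Onyx code):

from datetime import datetime
from typing import Optional

stats: dict[str, int] = {"num_verified_documents": 3}

# dict.get(key) already returns None for a missing key, so the explicit
# ", None" default argument was redundant.
assert stats.get("verified_avg_score") is None
assert stats.get("verified_avg_score", None) is None


def duration_or_none(start: Optional[datetime], end: Optional[datetime]) -> Optional[float]:
    # Same guard as the new code: only subtract when both timestamps exist.
    if start and end:
        return (end - start).total_seconds()
    return None


print(duration_or_none(None, datetime.now()))            # None
print(duration_or_none(datetime.now(), datetime.now()))  # small non-negative float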
@@ -23,8 +23,8 @@ from onyx.agents.agent_search.deep_search_a.main.nodes.create_refined_sub_questi
 from onyx.agents.agent_search.deep_search_a.main.nodes.decide_refinement_need import (
     decide_refinement_need,
 )
-from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entity_term import (
-    extract_entity_term,
+from onyx.agents.agent_search.deep_search_a.main.nodes.extract_entities_terms import (
+    extract_entities_terms,
 )
 from onyx.agents.agent_search.deep_search_a.main.nodes.generate_refined_answer import (
     generate_refined_answer,
@@ -116,7 +116,7 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:

     graph.add_node(
         node="extract_entity_term",
-        action=extract_entity_term,
+        action=extract_entities_terms,
     )
     graph.add_node(
         node="decide_refinement_need",
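Note that only the imported callable is renamed; the node is still registered under the string "extract_entity_term", and graph wiring addresses nodes by that registered name. A toy sketch of this distinction, assuming the LangGraph StateGraph API that the imports in this diff suggest (the state and values here are made up):

from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class MiniState(TypedDict):
    counter: int


def extract_entities_terms(state: MiniState) -> dict:
    # The Python function can be renamed freely; the wiring below uses the node name.
    return {"counter": state["counter"] + 1}


graph = StateGraph(MiniState)
# The registered node name stays "extract_entity_term" even after the rename.
graph.add_node("extract_entity_term", extract_entities_terms)
graph.add_edge(START, "extract_entity_term")
graph.add_edge("extract_entity_term", END)

print(graph.compile().invoke({"counter": 0}))  # {'counter': 1}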
@@ -25,7 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROM
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs


-def extract_entity_term(
+def extract_entities_terms(
     state: MainState, config: RunnableConfig
 ) -> EntityTermExtractionUpdate:
     now_start = datetime.now()
@@ -68,7 +68,11 @@ def extract_entity_term(
     )

     cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
-    parsed_response = json.loads(cleaned_response)
+    try:
+        parsed_response = json.loads(cleaned_response)
+    except json.JSONDecodeError:
+        logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
+        parsed_response = {}

     entities = []
     relationships = []
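Wrapping the parse in try/except means a malformed LLM response now degrades to an empty extraction instead of raising and failing the node. A self-contained sketch of the same clean-then-parse-with-fallback pattern (the sample payload is invented for illustration):

import json
import re

raw_llm_output = '```json\n{"retrieved_entities_relationships": {"entities": []}}\n```'

# Strip the markdown fence the model tends to wrap its JSON in.
cleaned_response = re.sub(r"```json\n|\n```", "", raw_llm_output)

try:
    parsed_response = json.loads(cleaned_response)
except json.JSONDecodeError:
    # Fall back to an empty dict so the downstream .get() chains simply
    # yield nothing rather than crashing the whole graph node.
    parsed_response = {}

print(parsed_response.get("retrieved_entities_relationships", {}).get("entities", []))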
@@ -76,37 +80,40 @@ def extract_entity_term(
     for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
         "entities", {}
     ):
-        entity_name = entity.get("entity_name", "")
-        entity_type = entity.get("entity_type", "")
-        entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
+        entity_name = entity.get("entity_name")
+        entity_type = entity.get("entity_type")
+        if entity_name and entity_type:
+            entities.append(Entity(entity_name=entity_name, entity_type=entity_type))

     for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
         "relationships", {}
     ):
-        relationship_name = relationship.get("relationship_name", "")
-        relationship_type = relationship.get("relationship_type", "")
-        relationship_entities = relationship.get("relationship_entities", [])
-        relationships.append(
-            Relationship(
-                relationship_name=relationship_name,
-                relationship_type=relationship_type,
-                relationship_entities=relationship_entities,
+        relationship_name = relationship.get("relationship_name")
+        relationship_type = relationship.get("relationship_type")
+        relationship_entities = relationship.get("relationship_entities")
+        if relationship_name and relationship_type and relationship_entities:
+            relationships.append(
+                Relationship(
+                    relationship_name=relationship_name,
+                    relationship_type=relationship_type,
+                    relationship_entities=relationship_entities,
+                )
             )
-        )

     for term in parsed_response.get("retrieved_entities_relationships", {}).get(
         "terms", {}
     ):
-        term_name = term.get("term_name", "")
-        term_type = term.get("term_type", "")
-        term_similar_to = term.get("term_similar_to", [])
-        terms.append(
-            Term(
-                term_name=term_name,
-                term_type=term_type,
-                term_similar_to=term_similar_to,
+        term_name = term.get("term_name")
+        term_type = term.get("term_type")
+        term_similar_to = term.get("term_similar_to")
+        if term_name and term_type and term_similar_to:
+            terms.append(
+                Term(
+                    term_name=term_name,
+                    term_type=term_type,
+                    term_similar_to=term_similar_to,
+                )
             )
-        )

     now_end = datetime.now()
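The loops now skip any entity, relationship, or term with a missing required field instead of constructing objects padded with empty strings or empty lists. A reduced sketch of the filtering step, using a stand-in dataclass rather than the real Entity model:

from dataclasses import dataclass


@dataclass
class Entity:  # stand-in for the real model; field names mirror the diff
    entity_name: str
    entity_type: str


parsed_entities = [
    {"entity_name": "Onyx", "entity_type": "product"},
    {"entity_name": "", "entity_type": "product"},  # empty name -> skipped
    {"entity_type": "company"},                     # missing name -> skipped
]

entities = []
for entity in parsed_entities:
    entity_name = entity.get("entity_name")
    entity_type = entity.get("entity_type")
    if entity_name and entity_type:  # truthiness check drops both None and ""
        entities.append(Entity(entity_name=entity_name, entity_type=entity_type))

print(entities)  # only the first item survives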
@@ -28,7 +28,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
     agent_end_time = agent_refined_end_time or agent_base_end_time

     agent_base_duration = None
-    if agent_base_end_time:
+    if agent_base_end_time and agent_start_time:
         agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()

     agent_refined_duration = None
@@ -38,7 +38,7 @@ def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutpu
         ).total_seconds()

     agent_full_duration = None
-    if agent_end_time:
+    if agent_end_time and agent_start_time:
         agent_full_duration = (agent_end_time - agent_start_time).total_seconds()

     agent_type = "refined" if agent_refined_duration else "base"
@@ -56,14 +56,14 @@ class RefinedAgentEndStats(BaseModel):


 class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
-    agent_start_time: datetime = datetime.now()
-    previous_history: str = ""
+    agent_start_time: datetime | None = None
+    previous_history: str | None = None
     initial_decomp_questions: list[str] = []


 class ExploratorySearchUpdate(LoggerUpdate):
     exploratory_search_results: list[InferenceSection] = []
-    previous_history_summary: str = ""
+    previous_history_summary: str | None = None


 class AnswerComparison(LoggerUpdate):
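The old default `agent_start_time: datetime = datetime.now()` is evaluated once, when the class body runs at import time, so every instance built without an explicit value would share that stale timestamp. Switching to `datetime | None = None` avoids this and lets downstream code handle a missing start time explicitly. A small sketch of the pitfall, assuming Pydantic models as the BaseModel base class in this file indicates:

import time
from datetime import datetime

from pydantic import BaseModel


class WithCallDefault(BaseModel):
    # datetime.now() runs exactly once, when this class body is executed.
    started_at: datetime = datetime.now()


class WithNoneDefault(BaseModel):
    started_at: datetime | None = None


a = WithCallDefault()
time.sleep(0.5)
b = WithCallDefault()
print(a.started_at == b.started_at)  # True: both share the import-time timestamp

c = WithNoneDefault()
print(c.started_at)  # None until the caller supplies a real start time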
@@ -71,24 +71,24 @@ class AnswerComparison(LoggerUpdate):


 class RoutingDecision(LoggerUpdate):
-    routing_decision: str = ""
+    routing_decision: str | None = None


 # Not used in current graph
 class InitialAnswerBASEUpdate(BaseModel):
-    initial_base_answer: str = ""
+    initial_base_answer: str


 class InitialAnswerUpdate(LoggerUpdate):
-    initial_answer: str = ""
+    initial_answer: str
     initial_agent_stats: InitialAgentResultStats | None = None
     generated_sub_questions: list[str] = []
-    agent_base_end_time: datetime | None = None
+    agent_base_end_time: datetime
     agent_base_metrics: AgentBaseMetrics | None = None


 class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
-    refined_answer: str = ""
+    refined_answer: str
     refined_agent_stats: RefinedAgentStats | None = None
     refined_answer_quality: bool = False
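Dropping the `= ""` defaults makes `initial_answer`, `refined_answer`, and `agent_base_end_time` required, so a node that forgets to set them now fails loudly at construction time instead of quietly emitting an empty answer. A minimal sketch with a simplified stand-in for the real state class:

from pydantic import BaseModel, ValidationError


class InitialAnswerUpdate(BaseModel):  # simplified stand-in for the real state class
    initial_answer: str                # required: no default value
    generated_sub_questions: list[str] = []


try:
    InitialAnswerUpdate()  # missing initial_answer
except ValidationError as err:
    print(err)  # reports that initial_answer is a required field

ok = InitialAnswerUpdate(initial_answer="The answer...")
print(ok.initial_answer)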
@@ -73,7 +73,7 @@ def calculate_sub_question_retrieval_stats(
         raw_chunk_stats_counts["verified_count"]
     )

-    rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
+    rejected_scores = raw_chunk_stats_scores.get("rejected_scores")
     if rejected_scores is not None:
         rejected_avg_scores = rejected_scores / float(
             raw_chunk_stats_counts["rejected_count"]
@@ -953,7 +953,7 @@ def log_agent_metrics(
     user_id: UUID | None,
     persona_id: int | None,  # Can be none if temporary persona is used
     agent_type: str,
-    start_time: datetime,
+    start_time: datetime | None,
     agent_metrics: CombinedAgentMetrics,
 ) -> AgentSearchMetrics:
     agent_timings = agent_metrics.timings