From 2adeaaeded702e6ccaba23146dd258119f599bdf Mon Sep 17 00:00:00 2001 From: Evan Lohn Date: Sat, 1 Feb 2025 23:09:09 -0800 Subject: [PATCH] loading object into model instead of json --- .../main/nodes/extract_entities_terms.py | 67 ++++--------------- .../agent_search/shared_graph_utils/models.py | 4 ++ .../agent_search/shared_graph_utils/utils.py | 56 +--------------- 3 files changed, 19 insertions(+), 108 deletions(-) diff --git a/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py b/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py index 33676ddb5..23e60c24b 100644 --- a/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py +++ b/backend/onyx/agents/agent_search/deep_search/main/nodes/extract_entities_terms.py @@ -1,4 +1,3 @@ -import json import re from datetime import datetime from typing import cast @@ -15,12 +14,10 @@ from onyx.agents.agent_search.models import GraphConfig from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( trim_prompt_piece, ) -from onyx.agents.agent_search.shared_graph_utils.models import Entity +from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult from onyx.agents.agent_search.shared_graph_utils.models import ( EntityRelationshipTermExtraction, ) -from onyx.agents.agent_search.shared_graph_utils.models import Relationship -from onyx.agents.agent_search.shared_graph_utils.models import Term from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT from onyx.agents.agent_search.shared_graph_utils.utils import format_docs from onyx.agents.agent_search.shared_graph_utils.utils import ( @@ -74,59 +71,23 @@ def extract_entities_terms( ) cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content)) + try: - parsed_response = json.loads(cleaned_response) - except json.JSONDecodeError: + entity_extraction_result = EntityExtractionResult.model_validate_json( + cleaned_response + ) + except ValueError: logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction") - parsed_response = {} - - entities = [] - relationships = [] - terms = [] - for entity in parsed_response.get("retrieved_entities_relationships", {}).get( - "entities", {} - ): - entity_name = entity.get("entity_name") - entity_type = entity.get("entity_type") - if entity_name and entity_type: - entities.append(Entity(entity_name=entity_name, entity_type=entity_type)) - - for relationship in parsed_response.get("retrieved_entities_relationships", {}).get( - "relationships", {} - ): - relationship_name = relationship.get("relationship_name") - relationship_type = relationship.get("relationship_type") - relationship_entities = relationship.get("relationship_entities") - if relationship_name and relationship_type and relationship_entities: - relationships.append( - Relationship( - relationship_name=relationship_name, - relationship_type=relationship_type, - relationship_entities=relationship_entities, - ) - ) - - for term in parsed_response.get("retrieved_entities_relationships", {}).get( - "terms", {} - ): - term_name = term.get("term_name") - term_type = term.get("term_type") - term_similar_to = term.get("term_similar_to") - if term_name and term_type and term_similar_to: - terms.append( - Term( - term_name=term_name, - term_type=term_type, - term_similar_to=term_similar_to, - ) - ) + entity_extraction_result = EntityExtractionResult( + retrieved_entities_relationships=EntityRelationshipTermExtraction( + entities=[], + relationships=[], + terms=[], + ), + ) return EntityTermExtractionUpdate( - entity_relation_term_extractions=EntityRelationshipTermExtraction( - entities=entities, - relationships=relationships, - terms=terms, - ), + entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships, log_messages=[ get_langgraph_node_log_string( graph_component="main", diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/models.py b/backend/onyx/agents/agent_search/shared_graph_utils/models.py index a07669193..a1f480874 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/models.py @@ -87,6 +87,10 @@ class EntityRelationshipTermExtraction(BaseModel): terms: list[Term] = [] +class EntityExtractionResult(BaseModel): + retrieved_entities_relationships: EntityRelationshipTermExtraction + + class QueryRetrievalResult(BaseModel): query: str retrieved_documents: list[InferenceSection] diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index 9bc8e6d4e..dce3c13d8 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -1,12 +1,9 @@ -import ast -import json import os import re from collections.abc import Callable from collections.abc import Iterator from collections.abc import Sequence from datetime import datetime -from datetime import timedelta from typing import Any from typing import cast from typing import Literal @@ -71,13 +68,6 @@ from onyx.tools.utils import explicit_tool_calling_supported BaseMessage_Content = str | list[str | dict[str, Any]] -def normalize_whitespace(text: str) -> str: - """Normalize whitespace in text to single spaces and strip leading/trailing whitespace.""" - import re - - return re.sub(r"\s+", " ", text.strip()) - - # Post-processing def format_docs(docs: Sequence[InferenceSection]) -> str: formatted_doc_list = [] @@ -88,43 +78,6 @@ def format_docs(docs: Sequence[InferenceSection]) -> str: return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list) -def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str: - formatted_doc_list = [] - - for _, doc in enumerate(docs): - formatted_doc_list.append(f"\n...{doc.combined_content}\n") - - return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list) - - -def clean_and_parse_list_string(json_string: str) -> list[dict]: - # Remove any prefixes/labels before the actual JSON content - json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL) - - # Remove markdown code block markers and any newline prefixes - cleaned_string = re.sub(r"```json\n|\n```", "", json_string) - cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ") - cleaned_string = " ".join(cleaned_string.split()) - - # Try parsing with json.loads first, fall back to ast.literal_eval - try: - return json.loads(cleaned_string) - except json.JSONDecodeError: - try: - return ast.literal_eval(cleaned_string) - except (ValueError, SyntaxError) as e: - raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e - - -def clean_and_parse_json_string(json_string: str) -> dict[str, Any]: - # Remove markdown code block markers and any newline prefixes - cleaned_string = re.sub(r"```json\n|\n```", "", json_string) - cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ") - cleaned_string = " ".join(cleaned_string.split()) - # Parse the cleaned string into a Python dictionary - return json.loads(cleaned_string) - - def format_entity_term_extraction( entity_term_extraction_dict: EntityRelationshipTermExtraction, ) -> str: @@ -161,12 +114,6 @@ def format_entity_term_extraction( return "\n".join(entity_strs + relationship_strs + term_strs) -def _format_time_delta(time: timedelta) -> str: - seconds_from_start = f"{((time).seconds):03d}" - microseconds_from_start = f"{((time).microseconds):06d}" - return f"{seconds_from_start}.{microseconds_from_start}" - - def get_test_config( db_session: Session, primary_llm: LLM, @@ -279,11 +226,10 @@ def get_persona_agent_prompt_expressions( persona: Persona | None, ) -> PersonaPromptExpressions: if persona is None: - persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT persona_base = "" + persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT else: persona_base = "\n".join([x.system_prompt for x in persona.prompts]) - persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( persona_prompt=persona_base )