load LLM response into a Pydantic model instead of hand-parsing JSON

Evan Lohn 2025-02-01 23:09:09 -08:00
parent a96728ff4d
commit 2adeaaeded
3 changed files with 19 additions and 108 deletions

View File

@@ -1,4 +1,3 @@
-import json
 import re
 from datetime import datetime
 from typing import cast
@@ -15,12 +14,10 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
-from onyx.agents.agent_search.shared_graph_utils.models import Entity
+from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
-from onyx.agents.agent_search.shared_graph_utils.models import Relationship
-from onyx.agents.agent_search.shared_graph_utils.models import Term
 from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -74,59 +71,23 @@ def extract_entities_terms(
     )
     cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
     try:
-        parsed_response = json.loads(cleaned_response)
-    except json.JSONDecodeError:
+        entity_extraction_result = EntityExtractionResult.model_validate_json(
+            cleaned_response
+        )
+    except ValueError:
         logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
-        parsed_response = {}
-
-    entities = []
-    relationships = []
-    terms = []
-
-    for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "entities", {}
-    ):
-        entity_name = entity.get("entity_name")
-        entity_type = entity.get("entity_type")
-        if entity_name and entity_type:
-            entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
-
-    for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "relationships", {}
-    ):
-        relationship_name = relationship.get("relationship_name")
-        relationship_type = relationship.get("relationship_type")
-        relationship_entities = relationship.get("relationship_entities")
-        if relationship_name and relationship_type and relationship_entities:
-            relationships.append(
-                Relationship(
-                    relationship_name=relationship_name,
-                    relationship_type=relationship_type,
-                    relationship_entities=relationship_entities,
-                )
-            )
-
-    for term in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "terms", {}
-    ):
-        term_name = term.get("term_name")
-        term_type = term.get("term_type")
-        term_similar_to = term.get("term_similar_to")
-        if term_name and term_type and term_similar_to:
-            terms.append(
-                Term(
-                    term_name=term_name,
-                    term_type=term_type,
-                    term_similar_to=term_similar_to,
-                )
-            )
+        entity_extraction_result = EntityExtractionResult(
+            retrieved_entities_relationships=EntityRelationshipTermExtraction(
+                entities=[],
+                relationships=[],
+                terms=[],
+            ),
+        )
 
     return EntityTermExtractionUpdate(
-        entity_relation_term_extractions=EntityRelationshipTermExtraction(
-            entities=entities,
-            relationships=relationships,
-            terms=terms,
-        ),
+        entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
         log_messages=[
             get_langgraph_node_log_string(
                 graph_component="main",
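
The heart of the change: instead of json.loads followed by hand-rolled dict traversal, the cleaned LLM output is validated straight into a Pydantic model. The new except ValueError clause is broader than it looks, since pydantic's ValidationError subclasses ValueError, so it catches both malformed JSON and schema mismatches. A minimal runnable sketch of the pattern, assuming pydantic v2 and simplified stand-ins for the real models in shared_graph_utils.models:

    import re

    from pydantic import BaseModel


    class Entity(BaseModel):
        entity_name: str
        entity_type: str


    class EntityRelationshipTermExtraction(BaseModel):
        entities: list[Entity] = []


    class EntityExtractionResult(BaseModel):
        retrieved_entities_relationships: EntityRelationshipTermExtraction


    # LLM output often arrives wrapped in markdown fences; strip them
    # first, exactly as the node above does.
    raw = '```json\n{"retrieved_entities_relationships": {"entities": [{"entity_name": "Onyx", "entity_type": "PRODUCT"}]}}\n```'
    cleaned = re.sub(r"```json\n|\n```", "", raw)

    try:
        result = EntityExtractionResult.model_validate_json(cleaned)
    except ValueError:
        # pydantic.ValidationError subclasses ValueError, so one clause
        # covers both invalid JSON and valid JSON with the wrong shape.
        result = EntityExtractionResult(
            retrieved_entities_relationships=EntityRelationshipTermExtraction()
        )

    print(result.retrieved_entities_relationships.entities[0].entity_name)  # Onyx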

View File

@@ -87,6 +87,10 @@ class EntityRelationshipTermExtraction(BaseModel):
     terms: list[Term] = []
 
 
+class EntityExtractionResult(BaseModel):
+    retrieved_entities_relationships: EntityRelationshipTermExtraction
+
+
 class QueryRetrievalResult(BaseModel):
     query: str
     retrieved_documents: list[InferenceSection]
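
The model-side half of the commit is this one-field wrapper. It mirrors the JSON envelope that ENTITY_TERM_PROMPT asks the LLM to emit, so the entire payload can be validated with a single model_validate_json call. Continuing the simplified sketch above; the payload below is invented for illustration:

    payload = """
    {
      "retrieved_entities_relationships": {
        "entities": [
          {"entity_name": "Acme", "entity_type": "ORGANIZATION"}
        ]
      }
    }
    """

    # Fields declared with defaults (like terms: list[Term] = [] above) let
    # keys the LLM omits fall back to empty lists instead of failing validation.
    result = EntityExtractionResult.model_validate_json(payload)
    print(len(result.retrieved_entities_relationships.entities))  # 1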

View File

@@ -1,12 +1,9 @@
-import ast
-import json
 import os
 import re
 from collections.abc import Callable
 from collections.abc import Iterator
 from collections.abc import Sequence
 from datetime import datetime
-from datetime import timedelta
 from typing import Any
 from typing import cast
 from typing import Literal
@@ -71,13 +68,6 @@ from onyx.tools.utils import explicit_tool_calling_supported
 BaseMessage_Content = str | list[str | dict[str, Any]]
 
 
-def normalize_whitespace(text: str) -> str:
-    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
-    import re
-
-    return re.sub(r"\s+", " ", text.strip())
-
-
 # Post-processing
 def format_docs(docs: Sequence[InferenceSection]) -> str:
     formatted_doc_list = []
@@ -88,43 +78,6 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
     return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
 
 
-def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
-    formatted_doc_list = []
-
-    for _, doc in enumerate(docs):
-        formatted_doc_list.append(f"\n...{doc.combined_content}\n")
-
-    return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
-
-
-def clean_and_parse_list_string(json_string: str) -> list[dict]:
-    # Remove any prefixes/labels before the actual JSON content
-    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
-
-    # Remove markdown code block markers and any newline prefixes
-    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
-    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
-    cleaned_string = " ".join(cleaned_string.split())
-
-    # Try parsing with json.loads first, fall back to ast.literal_eval
-    try:
-        return json.loads(cleaned_string)
-    except json.JSONDecodeError:
-        try:
-            return ast.literal_eval(cleaned_string)
-        except (ValueError, SyntaxError) as e:
-            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
-
-
-def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
-    # Remove markdown code block markers and any newline prefixes
-    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
-    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
-    cleaned_string = " ".join(cleaned_string.split())
-
-    # Parse the cleaned string into a Python dictionary
-    return json.loads(cleaned_string)
-
-
 def format_entity_term_extraction(
     entity_term_extraction_dict: EntityRelationshipTermExtraction,
 ) -> str:
@@ -161,12 +114,6 @@ def format_entity_term_extraction(
     return "\n".join(entity_strs + relationship_strs + term_strs)
 
 
-def _format_time_delta(time: timedelta) -> str:
-    seconds_from_start = f"{((time).seconds):03d}"
-    microseconds_from_start = f"{((time).microseconds):06d}"
-    return f"{seconds_from_start}.{microseconds_from_start}"
-
-
 def get_test_config(
     db_session: Session,
     primary_llm: LLM,
@@ -279,11 +226,10 @@ def get_persona_agent_prompt_expressions(
     persona: Persona | None,
 ) -> PersonaPromptExpressions:
     if persona is None:
-        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
         persona_base = ""
+        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
     else:
         persona_base = "\n".join([x.system_prompt for x in persona.prompts])
         persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
             persona_prompt=persona_base
         )
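
The utils cleanup deletes the lenient fence-stripping parsers (json.loads with an ast.literal_eval fallback) that the Pydantic path makes redundant, along with the unused normalize_whitespace, format_docs_content_flat, and _format_time_delta helpers. Nothing in this commit replaces clean_and_parse_list_string; but if a similar call site ever needs list-shaped LLM output validated without a dedicated BaseModel, pydantic v2's TypeAdapter is the analogous one-call tool. A hedged sketch, not code from this commit:

    import re

    from pydantic import TypeAdapter
    from pydantic import ValidationError


    def parse_json_list(raw: str) -> list[dict]:
        # Hypothetical helper for illustration only. Drops any prefix before
        # the JSON array and markdown fences, then validates in one call.
        cleaned = re.sub(r"^.*?(?=\[)", "", raw, flags=re.DOTALL)
        cleaned = re.sub(r"```json\n|\n```", "", cleaned)
        try:
            return TypeAdapter(list[dict]).validate_json(cleaned)
        except ValidationError:
            return []


    print(parse_json_list('```json\n[{"step": 1}]\n```'))  # [{'step': 1}]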