loading object into model instead of json

commit 2adeaaeded · parent a96728ff4d
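In short, the change swaps json.loads plus hand-rolled dict traversal for Pydantic's model_validate_json, so the LLM response is validated straight into typed models and any failure falls back to a well-typed empty result. A minimal sketch of the pattern, assuming Pydantic v2 (the model and field names here are illustrative, not the repo's):

    import re

    from pydantic import BaseModel

    class Item(BaseModel):
        name: str
        kind: str

    class ExampleResult(BaseModel):
        items: list[Item] = []

    raw = '```json\n{"items": [{"name": "Onyx", "kind": "product"}]}\n```'
    # Strip markdown code fences, as the diff below does before parsing.
    cleaned = re.sub(r"```json\n|\n```", "", raw)

    try:
        # One call both parses the JSON and validates field names/types.
        result = ExampleResult.model_validate_json(cleaned)
    except ValueError:
        # Pydantic v2's ValidationError subclasses ValueError, so this
        # catches malformed JSON and schema mismatches alike.
        result = ExampleResult()

This also explains why the new code catches ValueError rather than json.JSONDecodeError: in Pydantic v2, ValidationError inherits from ValueError, so one clause covers both failure modes.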
@@ -1,4 +1,3 @@
-import json
 import re
 from datetime import datetime
 from typing import cast
@@ -15,12 +14,10 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
-from onyx.agents.agent_search.shared_graph_utils.models import Entity
+from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
-from onyx.agents.agent_search.shared_graph_utils.models import Relationship
-from onyx.agents.agent_search.shared_graph_utils.models import Term
 from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -74,59 +71,23 @@ def extract_entities_terms(
     )

     cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))

     try:
-        parsed_response = json.loads(cleaned_response)
-    except json.JSONDecodeError:
+        entity_extraction_result = EntityExtractionResult.model_validate_json(
+            cleaned_response
+        )
+    except ValueError:
         logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
-        parsed_response = {}
-
-    entities = []
-    relationships = []
-    terms = []
-    for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "entities", {}
-    ):
-        entity_name = entity.get("entity_name")
-        entity_type = entity.get("entity_type")
-        if entity_name and entity_type:
-            entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
-
-    for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "relationships", {}
-    ):
-        relationship_name = relationship.get("relationship_name")
-        relationship_type = relationship.get("relationship_type")
-        relationship_entities = relationship.get("relationship_entities")
-        if relationship_name and relationship_type and relationship_entities:
-            relationships.append(
-                Relationship(
-                    relationship_name=relationship_name,
-                    relationship_type=relationship_type,
-                    relationship_entities=relationship_entities,
-                )
-            )
-
-    for term in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "terms", {}
-    ):
-        term_name = term.get("term_name")
-        term_type = term.get("term_type")
-        term_similar_to = term.get("term_similar_to")
-        if term_name and term_type and term_similar_to:
-            terms.append(
-                Term(
-                    term_name=term_name,
-                    term_type=term_type,
-                    term_similar_to=term_similar_to,
-                )
-            )
+        entity_extraction_result = EntityExtractionResult(
+            retrieved_entities_relationships=EntityRelationshipTermExtraction(
+                entities=[],
+                relationships=[],
+                terms=[],
+            ),
+        )

     return EntityTermExtractionUpdate(
-        entity_relation_term_extractions=EntityRelationshipTermExtraction(
-            entities=entities,
-            relationships=relationships,
-            terms=terms,
-        ),
+        entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
         log_messages=[
             get_langgraph_node_log_string(
                 graph_component="main",
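One behavioral shift worth noting (an observation, not something the commit states): the old loops skipped individual entries with missing keys, while model validation is all-or-nothing, so a single malformed entry now sends the whole response into the empty-result fallback. Assuming Entity declares its fields as required, for example:

    from pydantic import BaseModel, ValidationError

    class Entity(BaseModel):
        entity_name: str
        entity_type: str

    try:
        Entity.model_validate({"entity_name": "Onyx"})  # entity_type missing
    except ValidationError as exc:
        print(exc.error_count())  # 1 -- the old code skipped such entries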
@@ -87,6 +87,10 @@ class EntityRelationshipTermExtraction(BaseModel):
     terms: list[Term] = []


+class EntityExtractionResult(BaseModel):
+    retrieved_entities_relationships: EntityRelationshipTermExtraction
+
+
 class QueryRetrievalResult(BaseModel):
     query: str
     retrieved_documents: list[InferenceSection]
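The new EntityExtractionResult model is a thin wrapper whose single field mirrors the LLM payload's top-level "retrieved_entities_relationships" key, letting the whole nested object validate recursively in one call. A small sketch of the same wrapper shape, with illustrative names:

    from pydantic import BaseModel

    class Inner(BaseModel):
        entities: list[str] = []

    class Wrapper(BaseModel):
        retrieved_entities_relationships: Inner

    # Nested keys validate recursively in a single call.
    w = Wrapper.model_validate_json(
        '{"retrieved_entities_relationships": {"entities": ["Onyx"]}}'
    )
    print(w.retrieved_entities_relationships.entities)  # ['Onyx']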
@@ -1,12 +1,9 @@
-import ast
-import json
 import os
-import re
 from collections.abc import Callable
 from collections.abc import Iterator
 from collections.abc import Sequence
 from datetime import datetime
 from datetime import timedelta
 from typing import Any
 from typing import cast
 from typing import Literal
@@ -71,13 +68,6 @@ from onyx.tools.utils import explicit_tool_calling_supported
 BaseMessage_Content = str | list[str | dict[str, Any]]


-def normalize_whitespace(text: str) -> str:
-    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
-    import re
-
-    return re.sub(r"\s+", " ", text.strip())
-
-
 # Post-processing
 def format_docs(docs: Sequence[InferenceSection]) -> str:
     formatted_doc_list = []
@@ -88,43 +78,6 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
     return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)


-def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
-    formatted_doc_list = []
-
-    for _, doc in enumerate(docs):
-        formatted_doc_list.append(f"\n...{doc.combined_content}\n")
-
-    return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
-
-
-def clean_and_parse_list_string(json_string: str) -> list[dict]:
-    # Remove any prefixes/labels before the actual JSON content
-    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
-
-    # Remove markdown code block markers and any newline prefixes
-    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
-    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
-    cleaned_string = " ".join(cleaned_string.split())
-
-    # Try parsing with json.loads first, fall back to ast.literal_eval
-    try:
-        return json.loads(cleaned_string)
-    except json.JSONDecodeError:
-        try:
-            return ast.literal_eval(cleaned_string)
-        except (ValueError, SyntaxError) as e:
-            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
-
-
-def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
-    # Remove markdown code block markers and any newline prefixes
-    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
-    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
-    cleaned_string = " ".join(cleaned_string.split())
-    # Parse the cleaned string into a Python dictionary
-    return json.loads(cleaned_string)
-
-
 def format_entity_term_extraction(
     entity_term_extraction_dict: EntityRelationshipTermExtraction,
 ) -> str:
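For context on the deleted helpers: the ast.literal_eval fallback existed because LLMs sometimes emit Python-literal dicts (single quotes) that strict JSON parsing rejects. A quick illustration:

    import ast
    import json

    s = "[{'name': 'Onyx'}]"  # Python literal syntax, not valid JSON
    try:
        json.loads(s)
    except json.JSONDecodeError:
        print(ast.literal_eval(s))  # [{'name': 'Onyx'}] via the fallback

With parsing now centralized in the Pydantic models, these bespoke cleanup paths and their ast/json/re imports become dead code in this module.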
@@ -161,12 +114,6 @@ def format_entity_term_extraction(
     return "\n".join(entity_strs + relationship_strs + term_strs)


-def _format_time_delta(time: timedelta) -> str:
-    seconds_from_start = f"{((time).seconds):03d}"
-    microseconds_from_start = f"{((time).microseconds):06d}"
-    return f"{seconds_from_start}.{microseconds_from_start}"
-
-
 def get_test_config(
     db_session: Session,
     primary_llm: LLM,
@@ -279,11 +226,10 @@ def get_persona_agent_prompt_expressions(
     persona: Persona | None,
 ) -> PersonaPromptExpressions:
     if persona is None:
-        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
         persona_base = ""
+        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
     else:
         persona_base = "\n".join([x.system_prompt for x in persona.prompts])
-
         persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
             persona_prompt=persona_base
         )