loading object into model instead of json
This commit is contained in:
parent a96728ff4d
commit 2adeaaeded
@@ -1,4 +1,3 @@
-import json
 import re
 from datetime import datetime
 from typing import cast
@@ -15,12 +14,10 @@ from onyx.agents.agent_search.models import GraphConfig
 from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
     trim_prompt_piece,
 )
-from onyx.agents.agent_search.shared_graph_utils.models import Entity
+from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
 from onyx.agents.agent_search.shared_graph_utils.models import (
     EntityRelationshipTermExtraction,
 )
-from onyx.agents.agent_search.shared_graph_utils.models import Relationship
-from onyx.agents.agent_search.shared_graph_utils.models import Term
 from onyx.agents.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
 from onyx.agents.agent_search.shared_graph_utils.utils import (
@@ -74,59 +71,23 @@ def extract_entities_terms(
     )
 
     cleaned_response = re.sub(r"```json\n|\n```", "", str(llm_response.content))
 
     try:
-        parsed_response = json.loads(cleaned_response)
-    except json.JSONDecodeError:
+        entity_extraction_result = EntityExtractionResult.model_validate_json(
+            cleaned_response
+        )
+    except ValueError:
         logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
-        parsed_response = {}
-    entities = []
-    relationships = []
-    terms = []
-    for entity in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "entities", {}
-    ):
-        entity_name = entity.get("entity_name")
-        entity_type = entity.get("entity_type")
-        if entity_name and entity_type:
-            entities.append(Entity(entity_name=entity_name, entity_type=entity_type))
-
-    for relationship in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "relationships", {}
-    ):
-        relationship_name = relationship.get("relationship_name")
-        relationship_type = relationship.get("relationship_type")
-        relationship_entities = relationship.get("relationship_entities")
-        if relationship_name and relationship_type and relationship_entities:
-            relationships.append(
-                Relationship(
-                    relationship_name=relationship_name,
-                    relationship_type=relationship_type,
-                    relationship_entities=relationship_entities,
-                )
-            )
-
-    for term in parsed_response.get("retrieved_entities_relationships", {}).get(
-        "terms", {}
-    ):
-        term_name = term.get("term_name")
-        term_type = term.get("term_type")
-        term_similar_to = term.get("term_similar_to")
-        if term_name and term_type and term_similar_to:
-            terms.append(
-                Term(
-                    term_name=term_name,
-                    term_type=term_type,
-                    term_similar_to=term_similar_to,
-                )
-            )
+        entity_extraction_result = EntityExtractionResult(
+            retrieved_entities_relationships=EntityRelationshipTermExtraction(
+                entities=[],
+                relationships=[],
+                terms=[],
+            ),
+        )
 
     return EntityTermExtractionUpdate(
-        entity_relation_term_extractions=EntityRelationshipTermExtraction(
-            entities=entities,
-            relationships=relationships,
-            terms=terms,
-        ),
+        entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
         log_messages=[
            get_langgraph_node_log_string(
                graph_component="main",
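Why a bare `except ValueError` is enough on the new path: Pydantic's `ValidationError` subclasses `ValueError`, and in Pydantic v2 `model_validate_json` raises it both for malformed JSON and for JSON that fails schema validation, so one handler covers what `json.JSONDecodeError` used to catch. A minimal sketch of that behavior, assuming Pydantic v2 (the `Inner`/`Result` models here are simplified stand-ins for `EntityRelationshipTermExtraction`/`EntityExtractionResult`, whose real definitions appear in the next hunk against `shared_graph_utils/models.py`):

```python
from pydantic import BaseModel, ValidationError


class Inner(BaseModel):
    entities: list[str] = []


class Result(BaseModel):  # simplified stand-in for EntityExtractionResult
    retrieved_entities_relationships: Inner


# Parses and validates the raw JSON string in a single step.
ok = Result.model_validate_json(
    '{"retrieved_entities_relationships": {"entities": ["Acme"]}}'
)
assert ok.retrieved_entities_relationships.entities == ["Acme"]

# Malformed JSON and schema mismatches both raise ValidationError,
# which subclasses ValueError -- hence the single `except ValueError`.
for bad in ("not json", '{"retrieved_entities_relationships": "wrong type"}'):
    try:
        Result.model_validate_json(bad)
    except ValueError as exc:
        assert isinstance(exc, ValidationError)
```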
@@ -87,6 +87,10 @@ class EntityRelationshipTermExtraction(BaseModel):
     terms: list[Term] = []
 
 
+class EntityExtractionResult(BaseModel):
+    retrieved_entities_relationships: EntityRelationshipTermExtraction
+
+
 class QueryRetrievalResult(BaseModel):
     query: str
     retrieved_documents: list[InferenceSection]
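The wrapper model added above exists only to mirror the top-level `retrieved_entities_relationships` key the LLM emits, so the whole response body validates in one call, and sub-lists the LLM omits fall back to the `= []` defaults declared on `EntityRelationshipTermExtraction`. A hedged sketch of that default behavior (field names are from this diff; `Relationship`/`Term` are simplified to plain dicts here for brevity):

```python
from pydantic import BaseModel


class Entity(BaseModel):
    entity_name: str
    entity_type: str


class EntityRelationshipTermExtraction(BaseModel):
    # Relationship/Term simplified to dicts in this sketch.
    entities: list[Entity] = []
    relationships: list[dict] = []
    terms: list[dict] = []


class EntityExtractionResult(BaseModel):
    retrieved_entities_relationships: EntityRelationshipTermExtraction


raw = (
    '{"retrieved_entities_relationships":'
    ' {"entities": [{"entity_name": "Acme", "entity_type": "company"}]}}'
)
result = EntityExtractionResult.model_validate_json(raw)

# Keys the LLM omitted ("relationships", "terms") take the declared [] defaults,
# so downstream code no longer needs defensive .get(..., {}) chains.
assert result.retrieved_entities_relationships.relationships == []
assert result.retrieved_entities_relationships.entities[0].entity_name == "Acme"
```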
@@ -1,12 +1,9 @@
-import ast
-import json
 import os
 import re
 from collections.abc import Callable
 from collections.abc import Iterator
 from collections.abc import Sequence
 from datetime import datetime
-from datetime import timedelta
 from typing import Any
 from typing import cast
 from typing import Literal
@@ -71,13 +68,6 @@ from onyx.tools.utils import explicit_tool_calling_supported
 BaseMessage_Content = str | list[str | dict[str, Any]]
 
 
-def normalize_whitespace(text: str) -> str:
-    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
-    import re
-
-    return re.sub(r"\s+", " ", text.strip())
-
-
 # Post-processing
 def format_docs(docs: Sequence[InferenceSection]) -> str:
     formatted_doc_list = []
@@ -88,43 +78,6 @@ def format_docs(docs: Sequence[InferenceSection]) -> str:
     return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
 
 
-def format_docs_content_flat(docs: Sequence[InferenceSection]) -> str:
-    formatted_doc_list = []
-
-    for _, doc in enumerate(docs):
-        formatted_doc_list.append(f"\n...{doc.combined_content}\n")
-
-    return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
-
-
-def clean_and_parse_list_string(json_string: str) -> list[dict]:
-    # Remove any prefixes/labels before the actual JSON content
-    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)
-
-    # Remove markdown code block markers and any newline prefixes
-    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
-    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
-    cleaned_string = " ".join(cleaned_string.split())
-
-    # Try parsing with json.loads first, fall back to ast.literal_eval
-    try:
-        return json.loads(cleaned_string)
-    except json.JSONDecodeError:
-        try:
-            return ast.literal_eval(cleaned_string)
-        except (ValueError, SyntaxError) as e:
-            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e
-
-
-def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
-    # Remove markdown code block markers and any newline prefixes
-    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
-    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
-    cleaned_string = " ".join(cleaned_string.split())
-    # Parse the cleaned string into a Python dictionary
-    return json.loads(cleaned_string)
-
-
 def format_entity_term_extraction(
     entity_term_extraction_dict: EntityRelationshipTermExtraction,
 ) -> str:
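With parsing now handled by the model, only the markdown-fence cleanup survives at the call site in `extract_entities_terms`; the deleted helpers' `json.loads`/`ast.literal_eval` fallback chain in `shared_graph_utils/utils.py` has no remaining callers. A quick standard-library demonstration of the fence-stripping regex the node keeps:

````python
import re

# Typical fenced LLM reply.
llm_content = '```json\n{"retrieved_entities_relationships": {}}\n```'

# Same pattern used in extract_entities_terms: drop the opening ```json
# fence and the closing fence, leaving bare JSON for model_validate_json.
cleaned = re.sub(r"```json\n|\n```", "", llm_content)
assert cleaned == '{"retrieved_entities_relationships": {}}'
````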
@@ -161,12 +114,6 @@ def format_entity_term_extraction(
     return "\n".join(entity_strs + relationship_strs + term_strs)
 
 
-def _format_time_delta(time: timedelta) -> str:
-    seconds_from_start = f"{((time).seconds):03d}"
-    microseconds_from_start = f"{((time).microseconds):06d}"
-    return f"{seconds_from_start}.{microseconds_from_start}"
-
-
 def get_test_config(
     db_session: Session,
     primary_llm: LLM,
@@ -279,11 +226,10 @@ def get_persona_agent_prompt_expressions(
     persona: Persona | None,
 ) -> PersonaPromptExpressions:
     if persona is None:
-        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
         persona_base = ""
+        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
     else:
         persona_base = "\n".join([x.system_prompt for x in persona.prompts])
-
         persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
             persona_prompt=persona_base
         )