application of content limitation in refined answer as well

joachim-danswer 2025-01-27 17:19:24 -08:00 committed by Evan Lohn
parent f2aeeb7b3c
commit 18d92559b5
13 changed files with 99 additions and 172 deletions

View File

@ -16,15 +16,9 @@ from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
@ -45,7 +39,7 @@ def answer_generation(
docs = state.documents
level, question_nr = parse_question_id(state.question_id)
context_docs = state.context_documents
persona_prompt = get_persona_prompt(agent_search_config.search_request.persona)
persona = get_persona_expressions(agent_search_config.search_request.persona)
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
@ -59,13 +53,6 @@ def answer_generation(
),
)
else:
if len(persona_prompt) > 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
logger.debug(f"Number of verified retrieval docs: {len(docs)}")
fast_llm = agent_search_config.fast_llm
@ -73,7 +60,7 @@ def answer_generation(
question=question,
original_question=agent_search_config.search_request.query,
docs=context_docs,
persona_specification=persona_specification,
persona_specification=persona.persona_prompt,
config=fast_llm.config,
)

View File

@ -31,12 +31,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResul
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
@ -49,7 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
@ -71,9 +65,9 @@ def generate_initial_answer(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
persona = get_persona_expressions(agent_a_config.search_request.persona)
history = build_history_prompt(agent_a_config.prompt_builder)
history = build_history_prompt(agent_a_config, question)
date_str = get_today_prompt()
@ -81,7 +75,7 @@ def generate_initial_answer(
sub_questions_cited_docs = state.cited_docs
all_original_question_documents = state.all_original_question_documents
consolidated_context_docs: list[InferenceSection] = []
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_docs
counter = 0
for original_doc_number, original_doc in enumerate(all_original_question_documents):
if original_doc_number not in sub_questions_cited_docs:
@ -174,15 +168,6 @@ def generate_initial_answer(
else:
sub_question_answer_str = ""
# Determine which persona-specification prompt to use
if len(persona_prompt) == 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
base_prompt = INITIAL_RAG_PROMPT
@ -193,7 +178,7 @@ def generate_initial_answer(
# summarize the history iff too long
if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
history = summarize_history(history, question, persona_specification, model)
history = summarize_history(history, question, persona.persona_base, model)
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
@ -201,7 +186,7 @@ def generate_initial_answer(
doc_context,
base_prompt
+ sub_question_answer_str
+ persona_specification
+ persona.persona_prompt
+ history
+ date_str,
)
@ -214,7 +199,7 @@ def generate_initial_answer(
sub_question_answer_str
),
relevant_docs=format_docs(relevant_docs),
persona_specification=persona_specification,
persona_specification=persona.persona_prompt,
history=history,
date_prompt=date_str,
)
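
The consolidation loop added above is the "content limitation" the commit title refers to: documents cited by sub-question answers are taken first, then original-question documents are topped up under two caps. The same loop is applied again in generate_refined_answer below. A self-contained sketch (constant values are illustrative, and cited documents are modeled as list indices here, matching the `original_doc_number not in sub_questions_cited_docs` check in the diff):

```python
AGENT_MIN_ORIG_QUESTION_DOCS = 3    # assumed value, from onyx.configs.agent_configs
AGENT_MAX_ANSWER_CONTEXT_DOCS = 10  # assumed value, from onyx.configs.agent_configs

def consolidate(cited_doc_indices: set[int], original_docs: list[str]) -> list[str]:
    # Documents already cited by sub-question answers are always kept.
    consolidated = [original_docs[i] for i in sorted(cited_doc_indices)]
    counter = 0
    for doc_number, doc in enumerate(original_docs):
        if doc_number in cited_doc_indices:
            continue
        # Guarantee a minimum share of original-question docs, but never
        # grow past the overall answer-context cap.
        if (
            counter <= AGENT_MIN_ORIG_QUESTION_DOCS
            or len(consolidated) < AGENT_MAX_ANSWER_CONTEXT_DOCS
        ):
            consolidated.append(doc)
            counter += 1
    return consolidated

docs = [f"doc-{i}" for i in range(25)]
print(len(consolidate({0, 2}, docs)))  # 10: capped well below len(docs)
```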

View File

@ -44,10 +44,10 @@ def initial_sub_question_creation(
perform_initial_search_decomposition = (
agent_a_config.perform_initial_search_decomposition
)
# perform_initial_search_path_decision = (
# agent_a_config.perform_initial_search_path_decision
# )
history = build_history_prompt(agent_a_config.prompt_builder)
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
history = build_history_prompt(agent_a_config, question)
# Use the initial search results to inform the decomposition
sample_doc_str = state.sample_doc_str if hasattr(state, "sample_doc_str") else ""
@ -85,9 +85,6 @@ def initial_sub_question_creation(
msg = [HumanMessage(content=decomposition_prompt)]
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
# Send the initial question as a subquestion with number 0
dispatch_custom_event(
"decomp_qs",

View File

@ -109,7 +109,6 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
],
)
logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
logger.debug(f"--------{now_end}--{now_end - now_start}--------LOGGING NODE END---")
return main_output

View File

@ -1,77 +1,25 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.deep_search_a.main.states import RoutingDecision
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import AGENT_DECISION_PROMPT
from onyx.llm.utils import check_number_of_tokens
def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
now_start = datetime.now()
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
agent_a_config.search_request.query
# perform_initial_search_path_decision = (
# agent_a_config.perform_initial_search_path_decision
# )
history = build_history_prompt(agent_a_config.prompt_builder)
logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")
# if perform_initial_search_path_decision:
# search_tool = agent_a_config.search_tool
# retrieved_docs: list[InferenceSection] = []
# # new db session to avoid concurrency issues
# with get_session_context_manager() as db_session:
# for tool_response in search_tool.run(
# query=question,
# force_no_rerank=True,
# alternate_db_session=db_session,
# ):
# # get retrieved docs to send to the rest of the graph
# if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
# response = cast(SearchResponseSummary, tool_response.response)
# retrieved_docs = response.top_sections
# break
# sample_doc_str = "\n\n".join(
# [doc.combined_content for _, doc in enumerate(retrieved_docs[:3])]
# )
# agent_decision_prompt = AGENT_DECISION_PROMPT_AFTER_SEARCH.format(
# question=question, sample_doc_str=sample_doc_str, history=history
# )
# else:
sample_doc_str = ""
agent_decision_prompt = AGENT_DECISION_PROMPT.format(
question=question, history=history
)
msg = [HumanMessage(content=agent_decision_prompt)]
# Get the rewritten queries in a defined format
model = agent_a_config.fast_llm
# no need to stream this
resp = model.invoke(msg)
if isinstance(resp.content, str) and "research" in resp.content.lower():
routing = "agent_search"
else:
routing = "LLM"
routing = "agent_search"
now_end = datetime.now()
@ -79,14 +27,10 @@ def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDeci
logger.debug(
f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
)
check_number_of_tokens(agent_decision_prompt)
return RoutingDecision(
# Decide which route to take
routing=routing,
sample_doc_str=sample_doc_str,
log_messages=[
f"{now_start} -- Path decision: {routing}, Time taken: {now_end - now_start}"
f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
],
)
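
After these deletions, the path-decision node no longer consults the LLM at all; every request is routed to agent search. Reassembled from the surviving lines (which context lines survive is partly inferred from the hunk counts):

```python
def agent_path_decision(state: MainState, config: RunnableConfig) -> RoutingDecision:
    now_start = datetime.now()
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    agent_a_config.search_request.query  # bare expression, kept as-is in the diff

    logger.debug(f"--------{now_start}--------DECIDING TO SEARCH OR GO TO LLM---")

    # The LLM-based research-vs-LLM choice (and the commented-out
    # search-informed variant) is gone; routing is now fixed.
    routing = "agent_search"

    now_end = datetime.now()
    logger.debug(
        f"--------{now_end}--{now_end - now_start}--------DECIDING TO SEARCH OR GO TO LLM END---"
    )
    return RoutingDecision(
        routing=routing,
        log_messages=[
            f"{now_end} -- Path decision: {routing}, Time taken: {now_end - now_start}"
        ],
    )
```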

View File

@ -7,6 +7,9 @@ from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import ExploratorySearchUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection
@ -23,6 +26,9 @@ def agent_search_start(
question = agent_a_config.search_request.query
chat_session_id = agent_a_config.chat_session_id
primary_message_id = agent_a_config.message_id
agent_a_config.fast_llm
history = build_history_prompt(agent_a_config, question)
if chat_session_id is None or primary_message_id is None:
raise ValueError(
@ -44,6 +50,7 @@ def agent_search_start(
return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results,
previous_history=history,
log_messages=[
f"{now_start} -- Main - Exploratory Search, Time taken: {now_end - now_start}"
],

View File

@ -11,14 +11,8 @@ from onyx.agents.agent_search.deep_search_a.main.operations import logger
from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
from onyx.agents.agent_search.deep_search_a.main.states import MainState
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
from onyx.chat.models import AgentAnswerPiece
@ -29,14 +23,7 @@ def direct_llm_handling(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
if len(persona_prompt) == 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
persona = get_persona_expressions(agent_a_config.search_request.persona)
logger.debug(f"--------{now_start}--------LLM HANDLING START---")
@ -45,7 +32,7 @@ def direct_llm_handling(
msg = [
HumanMessage(
content=DIRECT_LLM_PROMPT.format(
persona_specification=persona_specification, question=question
persona_specification=persona.persona_prompt, question=question
)
)
]

View File

@ -22,16 +22,11 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import REVISED_RAG_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
REVISED_RAG_PROMPT_NO_SUB_QUESTIONS,
@ -44,11 +39,13 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
@ -61,14 +58,36 @@ def generate_refined_answer(
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
question = agent_a_config.search_request.query
persona_prompt = get_persona_prompt(agent_a_config.search_request.persona)
persona = get_persona_expressions(agent_a_config.search_request.persona)
history = build_history_prompt(agent_a_config.prompt_builder)
history = build_history_prompt(agent_a_config, question)
date_str = get_today_prompt()
initial_documents = state.documents
revised_documents = state.refined_documents
sub_questions_cited_docs = state.cited_docs
combined_documents = dedup_inference_sections(initial_documents, revised_documents)
state.context_documents
sub_questions_cited_docs = state.cited_docs
all_original_question_documents = state.all_original_question_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_docs
counter = 0
for original_doc_number, original_doc in enumerate(all_original_question_documents):
if original_doc_number not in sub_questions_cited_docs:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
combined_documents = relevant_docs
# combined_documents = dedup_inference_sections(initial_documents, revised_documents)
query_info = get_query_info(state.original_question_retrieval_results)
if agent_a_config.search_tool is None:
@ -157,13 +176,6 @@ def generate_refined_answer(
# Determine which persona-specification prompt to use
if len(persona_prompt) == 0:
persona_specification = ASSISTANT_SYSTEM_PROMPT_DEFAULT
else:
persona_specification = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_prompt
)
# Determine which base prompt to use given the sub-question information
if len(good_qa_list) > 0:
base_prompt = REVISED_RAG_PROMPT
@ -171,16 +183,15 @@ def generate_refined_answer(
base_prompt = REVISED_RAG_PROMPT_NO_SUB_QUESTIONS
model = agent_a_config.fast_llm
relevant_docs = format_docs(combined_documents)
relevant_docs = trim_prompt_piece(
relevant_docs_str = format_docs(combined_documents)
relevant_docs_str = trim_prompt_piece(
model.config,
relevant_docs,
relevant_docs_str,
base_prompt
+ question
+ sub_question_answer_str
+ relevant_docs
+ initial_answer
+ persona_specification
+ persona.persona_prompt
+ history,
)
@ -194,7 +205,7 @@ def generate_refined_answer(
),
relevant_docs=relevant_docs,
initial_answer=remove_document_citations(initial_answer),
persona_specification=persona_specification,
persona_specification=persona.persona_prompt,
date_prompt=date_str,
)
)
@ -229,34 +240,14 @@ def generate_refined_answer(
# state.decomp_answer_results, state.original_question_retrieval_stats
# )
initial_good_sub_questions_str = "\n".join(list(set(initial_good_sub_questions)))
new_revised_good_sub_questions_str = "\n".join(
list(set(new_revised_good_sub_questions))
)
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=revision_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(
f"\n\n---INITIAL ANSWER START---\n\n Answer:\n Agent: {initial_answer}"
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER START---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
logger.debug(f"\n\nINITAL Sub-Questions\n\n{initial_good_sub_questions_str}\n\n")
logger.debug("-" * 10)
logger.debug(
f"\n\nNEW REVISED Sub-Questions\n\n{new_revised_good_sub_questions_str}\n\n"
)
logger.debug("-" * 100)
logger.debug(
f"\n\nINITAL & REVISED Sub-Questions & Answers:\n\n{sub_question_answer_str}\n\nStas:\n\n"
)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)

View File

@ -52,7 +52,7 @@ def refined_sub_question_creation(
question = agent_a_config.search_request.query
base_answer = state.initial_answer
history = build_history_prompt(agent_a_config.prompt_builder)
history = build_history_prompt(agent_a_config, question)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state.entity_relation_term_extractions

View File

@ -53,11 +53,13 @@ class RefinedAgentEndStats(BaseModel):
class BaseDecompUpdate(RefinedAgentStartStats, RefinedAgentEndStats):
agent_start_time: datetime = datetime.now()
previous_history: str = ""
initial_decomp_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history: str = ""
class AnswerComparison(LoggerUpdate):
@ -66,7 +68,6 @@ class AnswerComparison(LoggerUpdate):
class RoutingDecision(LoggerUpdate):
routing: str = ""
sample_doc_str: str = ""
class InitialAnswerBASEUpdate(BaseModel):
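
For context on the previous_history fields added above: nodes in this graph return partial update models that the framework merges into MainState, so threading history through BaseDecompUpdate and ExploratorySearchUpdate lets later nodes reuse it without rebuilding it. A minimal self-contained sketch of that pattern (class names from the diff; merge mechanics simplified, and plain strings stand in for InferenceSection):

```python
from pydantic import BaseModel

class LoggerUpdate(BaseModel):
    log_messages: list[str] = []

class ExploratorySearchUpdate(LoggerUpdate):
    exploratory_search_results: list[str] = []  # InferenceSection in the repo
    previous_history: str = ""

# A node returns only the fields it changes; the graph merges them into state.
update = ExploratorySearchUpdate(previous_history="USER: hi\nAI: hello")
print(update.model_dump())
```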

View File

@ -3,10 +3,13 @@ from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
@ -73,7 +76,11 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
)
def build_history_prompt(prompt_builder: AnswerPromptBuilder | None) -> str:
def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
prompt_builder = config.prompt_builder
model = config.fast_llm
persona_base = get_persona_expressions(config.search_request.persona).persona_base
if prompt_builder is None:
return ""
@ -97,4 +104,8 @@ def build_history_prompt(prompt_builder: AnswerPromptBuilder | None) -> str:
else:
continue
history = "\n".join(history_components)
if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
history = summarize_history(history, question, persona_base, model)
return HISTORY_PROMPT.format(history=history) if history else ""
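
Condensed, the new build_history_prompt reads roughly as below. The message-walking middle of the function is not shown in the diff, so collect_history_components is a hypothetical stand-in for it; the signature change (the function now takes the full AgentSearchConfig plus the question) and the in-place summarization come from the hunks above:

```python
def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
    prompt_builder = config.prompt_builder
    model = config.fast_llm
    persona_base = get_persona_expressions(config.search_request.persona).persona_base

    if prompt_builder is None:
        return ""

    # Hypothetical stand-in for the message iteration elided in the diff.
    history = "\n".join(collect_history_components(prompt_builder))

    # New: summarization happens inside the history builder, keyed on the raw
    # persona text (persona_base) rather than the formatted system prompt.
    if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
        history = summarize_history(history, question, persona_base, model)

    return HISTORY_PROMPT.format(history=history) if history else ""
```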

View File

@ -112,3 +112,8 @@ class CombinedAgentMetrics(BaseModel):
base_metrics: AgentBaseMetrics | None
refined_metrics: AgentRefinedMetrics
additional_metrics: AgentAdditionalMetrics
class PersonaExpressions(BaseModel):
persona_prompt: str
persona_base: str

View File

@ -19,6 +19,13 @@ from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaExpressions
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.agents.agent_search.shared_graph_utils.prompts import DATE_PROMPT
from onyx.agents.agent_search.shared_graph_utils.prompts import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
@ -250,11 +257,17 @@ def get_test_config(
return config, search_tool
def get_persona_prompt(persona: Persona | None) -> str:
def get_persona_expressions(persona: Persona | None) -> PersonaExpressions:
if persona is None:
return ""
persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
persona_base = ""
else:
return "\n".join([x.system_prompt for x in persona.prompts])
persona_base = "\n".join([x.system_prompt for x in persona.prompts])
persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=persona_base
)
return PersonaExpressions(persona_prompt=persona_prompt, persona_base=persona_base)
def make_question_id(level: int, question_nr: int) -> str:
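
Finally, the helper that replaces get_persona_prompt, reassembled from the interleaved old/new lines above. Callers that need a ready-made system prompt use persona_prompt; callers that need the raw persona text (e.g. summarize_history) use persona_base:

```python
class PersonaExpressions(BaseModel):
    persona_prompt: str  # formatted system prompt, ready for templates
    persona_base: str    # raw concatenated persona text ("" when no persona)

def get_persona_expressions(persona: Persona | None) -> PersonaExpressions:
    if persona is None:
        persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
        persona_base = ""
    else:
        persona_base = "\n".join([x.system_prompt for x in persona.prompts])
        persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
            persona_prompt=persona_base
        )
    return PersonaExpressions(persona_prompt=persona_prompt, persona_base=persona_base)
```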