diff --git a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py index 5ae2f582534d..0bf87b094046 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py +++ b/backend/onyx/agents/agent_search/deep_search_a/answer_initial_sub_question/nodes/answer_generation.py @@ -18,7 +18,9 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import ( ) from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids -from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions +from onyx.agents.agent_search.shared_graph_utils.utils import ( + get_persona_agent_prompt_expressions, +) from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import StreamStopInfo @@ -40,7 +42,9 @@ def answer_generation( state.documents level, question_nr = parse_question_id(state.question_id) context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS] - persona = get_persona_expressions(agent_search_config.search_request.persona) + persona_contextualized_prompt = get_persona_agent_prompt_expressions( + agent_search_config.search_request.persona + ).contextualized_prompt if len(context_docs) == 0: answer_str = NO_RECOVERED_DOCS @@ -61,7 +65,7 @@ def answer_generation( question=question, original_question=agent_search_config.search_request.query, docs=context_docs, - persona_specification=persona.persona_prompt, + persona_specification=persona_contextualized_prompt, config=fast_llm.config, ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/initial_search_sq_subgraph/nodes/generate_initial_answer.py 
b/backend/onyx/agents/agent_search/deep_search_a/initial_search_sq_subgraph/nodes/generate_initial_answer.py index 57e2be47cec4..260eb7b3bbf9 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/initial_search_sq_subgraph/nodes/generate_initial_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/initial_search_sq_subgraph/nodes/generate_initial_answer.py @@ -43,7 +43,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import ( dispatch_main_answer_stop_info, ) from onyx.agents.agent_search.shared_graph_utils.utils import format_docs -from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions +from onyx.agents.agent_search.shared_graph_utils.utils import ( + get_persona_agent_prompt_expressions, +) from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history @@ -65,7 +67,9 @@ def generate_initial_answer( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) question = agent_a_config.search_request.query - persona = get_persona_expressions(agent_a_config.search_request.persona) + persona_prompts = get_persona_agent_prompt_expressions( + agent_a_config.search_request.persona + ) history = build_history_prompt(agent_a_config, question) @@ -134,7 +138,6 @@ def generate_initial_answer( ) else: - decomp_answer_results = state.decomp_answer_results good_qa_list: list[str] = [] @@ -173,7 +176,9 @@ def generate_initial_answer( # summarize the history iff too long if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH: - history = summarize_history(history, question, persona.persona_base, model) + history = summarize_history( + history, question, persona_prompts.base_prompt, model + ) doc_context = format_docs(relevant_docs) doc_context = trim_prompt_piece( @@ -181,7 +186,7 @@ def generate_initial_answer( doc_context, base_prompt + 
sub_question_answer_str - + persona.persona_prompt + + persona_prompts.contextualized_prompt + history + date_str, ) @@ -194,7 +199,7 @@ def generate_initial_answer( sub_question_answer_str ), relevant_docs=format_docs(relevant_docs), - persona_specification=persona.persona_prompt, + persona_specification=persona_prompts.contextualized_prompt, history=history, date_prompt=date_str, ) diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py index d94d631ffb55..1e244c8d8c98 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/direct_llm_handling.py @@ -12,7 +12,9 @@ from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpda from onyx.agents.agent_search.deep_search_a.main.states import MainState from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT -from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions +from onyx.agents.agent_search.shared_graph_utils.utils import ( + get_persona_agent_prompt_expressions, +) from onyx.chat.models import AgentAnswerPiece @@ -23,7 +25,9 @@ def direct_llm_handling( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) question = agent_a_config.search_request.query - persona = get_persona_expressions(agent_a_config.search_request.persona) + persona_contextualized_prompt = get_persona_agent_prompt_expressions( + agent_a_config.search_request.persona + ).contextualized_prompt logger.info(f"--------{now_start}--------LLM HANDLING START---") @@ -32,7 +36,8 @@ def direct_llm_handling( msg = [ HumanMessage( content=DIRECT_LLM_PROMPT.format( - persona_specification=persona.persona_prompt, question=question + persona_specification=persona_contextualized_prompt,
+ question=question, ) ) ] diff --git a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py index c63ad6cc0b60..ed10af303670 100644 --- a/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py +++ b/backend/onyx/agents/agent_search/deep_search_a/main/nodes/generate_refined_answer.py @@ -39,7 +39,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import ( dispatch_main_answer_stop_info, ) from onyx.agents.agent_search.shared_graph_utils.utils import format_docs -from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions +from onyx.agents.agent_search.shared_graph_utils.utils import ( + get_persona_agent_prompt_expressions, +) from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id from onyx.chat.models import AgentAnswerPiece @@ -58,7 +60,9 @@ def generate_refined_answer( agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"]) question = agent_a_config.search_request.query - persona = get_persona_expressions(agent_a_config.search_request.persona) + persona_contextualized_prompt = get_persona_agent_prompt_expressions( + agent_a_config.search_request.persona + ).contextualized_prompt history = build_history_prompt(agent_a_config, question) date_str = get_today_prompt() @@ -188,7 +192,7 @@ def generate_refined_answer( + question + sub_question_answer_str + initial_answer - + persona.persona_prompt + + persona_contextualized_prompt + history, ) @@ -202,7 +206,7 @@ def generate_refined_answer( ), relevant_docs=relevant_docs, initial_answer=remove_document_citations(initial_answer), - persona_specification=persona.persona_prompt, + persona_specification=persona_contextualized_prompt, date_prompt=date_str, ) ) diff --git 
a/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py b/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py index 0366d988c6b5..0d66f395ce04 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/agent_prompt_ops.py @@ -6,7 +6,9 @@ from langchain_core.messages.tool import ToolMessage from onyx.agents.agent_search.models import AgentSearchConfig from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2 from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT -from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions +from onyx.agents.agent_search.shared_graph_utils.utils import ( + get_persona_agent_prompt_expressions, +) from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH @@ -79,7 +81,9 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) - def build_history_prompt(config: AgentSearchConfig, question: str) -> str: prompt_builder = config.prompt_builder model = config.fast_llm - persona_base = get_persona_expressions(config.search_request.persona).persona_base + persona_base = get_persona_agent_prompt_expressions( + config.search_request.persona + ).base_prompt if prompt_builder is None: return "" diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/models.py b/backend/onyx/agents/agent_search/shared_graph_utils/models.py index b548289c9498..ccf2ce6dd460 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/models.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/models.py @@ -115,5 +115,5 @@ class CombinedAgentMetrics(BaseModel): class PersonaExpressions(BaseModel): - persona_prompt: str - persona_base: str + 
contextualized_prompt: str + base_prompt: str diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index 743425ac6cf3..f9cf6dd18baa 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -257,7 +257,7 @@ def get_test_config( return config, search_tool -def get_persona_expressions(persona: Persona | None) -> PersonaExpressions: +def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions: if persona is None: persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT persona_base = "" @@ -267,7 +267,9 @@ def get_persona_expressions(persona: Persona | None) -> PersonaExpressions: persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format( persona_prompt=persona_base ) - return PersonaExpressions(persona_prompt=persona_prompt, persona_base=persona_base) + return PersonaExpressions( + contextualized_prompt=persona_prompt, base_prompt=persona_base + ) def make_question_id(level: int, question_nr: int) -> str: