persona_prompt improvements

joachim-danswer authored 2025-01-29 07:57:22 -08:00; committed by Evan Lohn
parent 4817fa0bd1
commit 6bef5ca7a4
7 changed files with 46 additions and 22 deletions

View File

@@ -18,7 +18,9 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
 )
 from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
 from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
-from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
+from onyx.agents.agent_search.shared_graph_utils.utils import (
+    get_persona_agent_prompt_expressions,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import StreamStopInfo
@@ -40,7 +42,9 @@ def answer_generation(
     state.documents
     level, question_nr = parse_question_id(state.question_id)
     context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
-    persona = get_persona_expressions(agent_search_config.search_request.persona)
+    persona_contextualized_prompt = get_persona_agent_prompt_expressions(
+        agent_search_config.search_request.persona
+    ).contextualized_prompt
 
     if len(context_docs) == 0:
         answer_str = NO_RECOVERED_DOCS
@@ -61,7 +65,7 @@ def answer_generation(
             question=question,
             original_question=agent_search_config.search_request.query,
             docs=context_docs,
-            persona_specification=persona.persona_prompt,
+            persona_specification=persona_contextualized_prompt,
             config=fast_llm.config,
         )

View File

@@ -43,7 +43,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
-from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
+from onyx.agents.agent_search.shared_graph_utils.utils import (
+    get_persona_agent_prompt_expressions,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
@@ -65,7 +67,9 @@ def generate_initial_answer(
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     question = agent_a_config.search_request.query
-    persona = get_persona_expressions(agent_a_config.search_request.persona)
+    persona_prompts = get_persona_agent_prompt_expressions(
+        agent_a_config.search_request.persona
+    )
 
     history = build_history_prompt(agent_a_config, question)
@@ -134,7 +138,6 @@ def generate_initial_answer(
         )
     else:
         decomp_answer_results = state.decomp_answer_results
-
     good_qa_list: list[str] = []
@@ -173,7 +176,9 @@ def generate_initial_answer(
 
     # summarize the history iff too long
     if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
-        history = summarize_history(history, question, persona.persona_base, model)
+        history = summarize_history(
+            history, question, persona_prompts.base_prompt, model
+        )
 
     doc_context = format_docs(relevant_docs)
     doc_context = trim_prompt_piece(
@@ -181,7 +186,7 @@ def generate_initial_answer(
         doc_context,
         base_prompt
         + sub_question_answer_str
-        + persona.persona_prompt
+        + persona_prompts.contextualized_prompt
         + history
         + date_str,
     )
@@ -194,7 +199,7 @@ def generate_initial_answer(
             sub_question_answer_str
         ),
         relevant_docs=format_docs(relevant_docs),
-        persona_specification=persona.persona_prompt,
+        persona_specification=persona_prompts.contextualized_prompt,
         history=history,
         date_prompt=date_str,
     )

View File

@@ -12,7 +12,9 @@ from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpdate
 from onyx.agents.agent_search.deep_search_a.main.states import MainState
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
-from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
+from onyx.agents.agent_search.shared_graph_utils.utils import (
+    get_persona_agent_prompt_expressions,
+)
 from onyx.chat.models import AgentAnswerPiece
@@ -23,7 +25,9 @@ def direct_llm_handling(
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     question = agent_a_config.search_request.query
-    persona = get_persona_expressions(agent_a_config.search_request.persona)
+    persona_contextualized_prompt = get_persona_agent_prompt_expressions(
+        agent_a_config.search_request.persona
+    ).contextualized_prompt
 
     logger.info(f"--------{now_start}--------LLM HANDLING START---")
@@ -32,7 +36,8 @@ def direct_llm_handling(
     msg = [
         HumanMessage(
             content=DIRECT_LLM_PROMPT.format(
-                persona_specification=persona.persona_prompt, question=question
+                persona_specification=persona_contextualized_prompt,
+                question=question,
             )
         )
     ]

View File

@@ -39,7 +39,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     dispatch_main_answer_stop_info,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
-from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
+from onyx.agents.agent_search.shared_graph_utils.utils import (
+    get_persona_agent_prompt_expressions,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.chat.models import AgentAnswerPiece
@@ -58,7 +60,9 @@ def generate_refined_answer(
     agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
     question = agent_a_config.search_request.query
-    persona = get_persona_expressions(agent_a_config.search_request.persona)
+    persona_contextualized_prompt = get_persona_agent_prompt_expressions(
+        agent_a_config.search_request.persona
+    ).contextualized_prompt
 
     history = build_history_prompt(agent_a_config, question)
     date_str = get_today_prompt()
@@ -188,7 +192,7 @@ def generate_refined_answer(
         + question
         + sub_question_answer_str
         + initial_answer
-        + persona.persona_prompt
+        + persona_contextualized_prompt
         + history,
     )
@@ -202,7 +206,7 @@ def generate_refined_answer(
             ),
             relevant_docs=relevant_docs,
             initial_answer=remove_document_citations(initial_answer),
-            persona_specification=persona.persona_prompt,
+            persona_specification=persona_contextualized_prompt,
             date_prompt=date_str,
         )
     )

View File

@@ -6,7 +6,9 @@ from langchain_core.messages.tool import ToolMessage
 from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
 from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
-from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
+from onyx.agents.agent_search.shared_graph_utils.utils import (
+    get_persona_agent_prompt_expressions,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
 from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
 from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH
@@ -79,7 +81,9 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
 def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
     prompt_builder = config.prompt_builder
     model = config.fast_llm
-    persona_base = get_persona_expressions(config.search_request.persona).persona_base
+    persona_base = get_persona_agent_prompt_expressions(
+        config.search_request.persona
+    ).base_prompt
 
     if prompt_builder is None:
         return ""

View File

@@ -115,5 +115,5 @@ class CombinedAgentMetrics(BaseModel):
 
 
 class PersonaExpressions(BaseModel):
-    persona_prompt: str
-    persona_base: str
+    contextualized_prompt: str
+    base_prompt: str
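
The model change is a pure field rename. For reference, a minimal sketch of the resulting class (assuming the module's existing pydantic import):

    from pydantic import BaseModel


    class PersonaExpressions(BaseModel):
        # persona specification wrapped in the assistant system-prompt template
        contextualized_prompt: str
        # bare persona prompt text, e.g. for history summarization
        base_prompt: str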

View File

@@ -257,7 +257,7 @@ def get_test_config(
     return config, search_tool
 
 
-def get_persona_expressions(persona: Persona | None) -> PersonaExpressions:
+def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
     if persona is None:
         persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
         persona_base = ""
@@ -267,7 +267,9 @@ def get_persona_expressions(persona: Persona | None) -> PersonaExpressions:
         persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
             persona_prompt=persona_base
         )
-    return PersonaExpressions(persona_prompt=persona_prompt, persona_base=persona_base)
+    return PersonaExpressions(
+        contextualized_prompt=persona_prompt, base_prompt=persona_base
+    )
 
 
 def make_question_id(level: int, question_nr: int) -> str:
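
Taken together, callers read the renamed fields off the returned model. A minimal usage sketch (persona construction omitted; names taken from this diff):

    persona_prompts = get_persona_agent_prompt_expressions(persona)
    # template-wrapped specification, interpolated into answer prompts
    persona_specification = persona_prompts.contextualized_prompt
    # raw persona text, passed to summarize_history when history is too long
    persona_base = persona_prompts.base_prompt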