mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-18 11:34:12 +02:00
persona_prompt improvements
This commit is contained in:
committed by
Evan Lohn
parent
4817fa0bd1
commit
6bef5ca7a4
@@ -18,7 +18,9 @@ from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import NO_RECOVERED_DOCS
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
@@ -40,7 +42,9 @@ def answer_generation(
|
||||
state.documents
|
||||
level, question_nr = parse_question_id(state.question_id)
|
||||
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
|
||||
persona = get_persona_expressions(agent_search_config.search_request.persona)
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_search_config.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
if len(context_docs) == 0:
|
||||
answer_str = NO_RECOVERED_DOCS
|
||||
@@ -61,7 +65,7 @@ def answer_generation(
|
||||
question=question,
|
||||
original_question=agent_search_config.search_request.query,
|
||||
docs=context_docs,
|
||||
persona_specification=persona.persona_prompt,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
|
@@ -43,7 +43,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
|
||||
@@ -65,7 +67,9 @@ def generate_initial_answer(
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona = get_persona_expressions(agent_a_config.search_request.persona)
|
||||
persona_prompts = get_persona_agent_prompt_expressions(
|
||||
agent_a_config.search_request.persona
|
||||
)
|
||||
|
||||
history = build_history_prompt(agent_a_config, question)
|
||||
|
||||
@@ -134,7 +138,6 @@ def generate_initial_answer(
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
decomp_answer_results = state.decomp_answer_results
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
@@ -173,7 +176,9 @@ def generate_initial_answer(
|
||||
|
||||
# summarize the history iff too long
|
||||
if len(history) > AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH:
|
||||
history = summarize_history(history, question, persona.persona_base, model)
|
||||
history = summarize_history(
|
||||
history, question, persona_prompts.base_prompt, model
|
||||
)
|
||||
|
||||
doc_context = format_docs(relevant_docs)
|
||||
doc_context = trim_prompt_piece(
|
||||
@@ -181,7 +186,7 @@ def generate_initial_answer(
|
||||
doc_context,
|
||||
base_prompt
|
||||
+ sub_question_answer_str
|
||||
+ persona.persona_prompt
|
||||
+ persona_prompts.contextualized_prompt
|
||||
+ history
|
||||
+ date_str,
|
||||
)
|
||||
@@ -194,7 +199,7 @@ def generate_initial_answer(
|
||||
sub_question_answer_str
|
||||
),
|
||||
relevant_docs=format_docs(relevant_docs),
|
||||
persona_specification=persona.persona_prompt,
|
||||
persona_specification=persona_prompts.contextualized_prompt,
|
||||
history=history,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
|
@@ -12,7 +12,9 @@ from onyx.agents.agent_search.deep_search_a.main.states import InitialAnswerUpda
|
||||
from onyx.agents.agent_search.deep_search_a.main.states import MainState
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import DIRECT_LLM_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
|
||||
|
||||
@@ -23,7 +25,9 @@ def direct_llm_handling(
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona = get_persona_expressions(agent_a_config.search_request.persona)
|
||||
persona_contextualialized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_a_config.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
logger.info(f"--------{now_start}--------LLM HANDLING START---")
|
||||
|
||||
@@ -32,7 +36,8 @@ def direct_llm_handling(
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DIRECT_LLM_PROMPT.format(
|
||||
persona_specification=persona.persona_prompt, question=question
|
||||
persona_specification=persona_contextualialized_prompt,
|
||||
question=question,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
@@ -39,7 +39,9 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -58,7 +60,9 @@ def generate_refined_answer(
|
||||
|
||||
agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
|
||||
question = agent_a_config.search_request.query
|
||||
persona = get_persona_expressions(agent_a_config.search_request.persona)
|
||||
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
|
||||
agent_a_config.search_request.persona
|
||||
).contextualized_prompt
|
||||
|
||||
history = build_history_prompt(agent_a_config, question)
|
||||
date_str = get_today_prompt()
|
||||
@@ -188,7 +192,7 @@ def generate_refined_answer(
|
||||
+ question
|
||||
+ sub_question_answer_str
|
||||
+ initial_answer
|
||||
+ persona.persona_prompt
|
||||
+ persona_contextualized_prompt
|
||||
+ history,
|
||||
)
|
||||
|
||||
@@ -202,7 +206,7 @@ def generate_refined_answer(
|
||||
),
|
||||
relevant_docs=relevant_docs,
|
||||
initial_answer=remove_document_citations(initial_answer),
|
||||
persona_specification=persona.persona_prompt,
|
||||
persona_specification=persona_contextualized_prompt,
|
||||
date_prompt=date_str,
|
||||
)
|
||||
)
|
||||
|
@@ -6,7 +6,9 @@ from langchain_core.messages.tool import ToolMessage
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT_v2
|
||||
from onyx.agents.agent_search.shared_graph_utils.prompts import HISTORY_PROMPT
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_persona_expressions
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_persona_agent_prompt_expressions,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_today_prompt
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_CHAR_LENGTH
|
||||
@@ -79,7 +81,9 @@ def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -
|
||||
def build_history_prompt(config: AgentSearchConfig, question: str) -> str:
|
||||
prompt_builder = config.prompt_builder
|
||||
model = config.fast_llm
|
||||
persona_base = get_persona_expressions(config.search_request.persona).persona_base
|
||||
persona_base = get_persona_agent_prompt_expressions(
|
||||
config.search_request.persona
|
||||
).base_prompt
|
||||
|
||||
if prompt_builder is None:
|
||||
return ""
|
||||
|
@@ -115,5 +115,5 @@ class CombinedAgentMetrics(BaseModel):
|
||||
|
||||
|
||||
class PersonaExpressions(BaseModel):
|
||||
persona_prompt: str
|
||||
persona_base: str
|
||||
contextualized_prompt: str
|
||||
base_prompt: str
|
||||
|
@@ -257,7 +257,7 @@ def get_test_config(
|
||||
return config, search_tool
|
||||
|
||||
|
||||
def get_persona_expressions(persona: Persona | None) -> PersonaExpressions:
|
||||
def get_persona_agent_prompt_expressions(persona: Persona | None) -> PersonaExpressions:
|
||||
if persona is None:
|
||||
persona_prompt = ASSISTANT_SYSTEM_PROMPT_DEFAULT
|
||||
persona_base = ""
|
||||
@@ -267,7 +267,9 @@ def get_persona_expressions(persona: Persona | None) -> PersonaExpressions:
|
||||
persona_prompt = ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
|
||||
persona_prompt=persona_base
|
||||
)
|
||||
return PersonaExpressions(persona_prompt=persona_prompt, persona_base=persona_base)
|
||||
return PersonaExpressions(
|
||||
contextualized_prompt=persona_prompt, base_prompt=persona_base
|
||||
)
|
||||
|
||||
|
||||
def make_question_id(level: int, question_nr: int) -> str:
|
||||
|
Reference in New Issue
Block a user