mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Add ability to specify persona in API request (#2302)
* persona * all prepared excluding configuration * more sensical model structure * update tstream * type updates * rm * quick and simple updates * minor updates * te * ensure typing + naming * remove old todo + rebase update * remove unnecessary check
This commit is contained in:
@@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -133,12 +134,23 @@ def get_answer_with_quote(
|
||||
query = query_request.messages[0].message
|
||||
logger.notice(f"Received query for one shot answer API with quotes: {query}")
|
||||
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
if query_request.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=query_request.persona_config,
|
||||
user=user,
|
||||
)
|
||||
persona = new_persona
|
||||
|
||||
elif query_request.persona_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
llm = get_main_llm_from_tuple(
|
||||
get_default_llms() if not persona else get_llms_for_persona(persona)
|
||||
|
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.one_shot_answer.models import PersonaConfig
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
@@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
||||
|
||||
@@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
time_created: datetime
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
|
||||
@@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str]:
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
@@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name
|
||||
if chat_session.persona
|
||||
else None,
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=feedback_type,
|
||||
)
|
||||
@@ -300,7 +302,7 @@ def snapshot_from_chat_session(
|
||||
for message in messages
|
||||
if message.message_type != MessageType.SYSTEM
|
||||
],
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||
time_created=chat_session.time_created,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user