Add ability to specify persona in API request (#2302)

* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
This commit is contained in:
pablodanswer
2024-09-16 14:31:01 -07:00
committed by GitHub
parent df464fc54b
commit 2dd3870504
14 changed files with 283 additions and 64 deletions

View File

@@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -133,12 +134,23 @@ def get_answer_with_quote(
query = query_request.messages[0].message
logger.notice(f"Received query for one shot answer API with quotes: {query}")
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
if query_request.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session,
persona_config=query_request.persona_config,
user=user,
)
persona = new_persona
elif query_request.persona_id is not None:
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
raise KeyError("Must provide persona ID or Persona Config")
llm = get_main_llm_from_tuple(
get_default_llms() if not persona else get_llms_for_persona(persona)

View File

@@ -0,0 +1,83 @@
from typing import cast
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
def create_temporary_persona(
persona_config: PersonaConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
name: str | None
first_user_message: str
first_ai_message: str
persona_name: str
persona_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
@@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
user_email: str
name: str | None
messages: list[MessageSnapshot]
persona_name: str
persona_name: str | None
time_created: datetime
@@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str
persona_name: str | None
user_email: str
time_created: datetime
@@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str]:
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
@@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name
if chat_session.persona
else None,
time_created=chat_session.time_created,
feedback_type=feedback_type,
)
@@ -300,7 +302,7 @@ def snapshot_from_chat_session(
for message in messages
if message.message_type != MessageType.SYSTEM
],
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name if chat_session.persona else None,
time_created=chat_session.time_created,
)