Add ability to specify persona in API request (#2302)

* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
pablodanswer 2024-09-16 14:31:01 -07:00 committed by GitHub
parent df464fc54b
commit 2dd3870504
14 changed files with 283 additions and 64 deletions

View File

@@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)


def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)
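
Note: the downgrade above re-imposes NOT NULL and will fail on databases where chat sessions have already been written with a NULL persona_id. A minimal sketch of a defensive variant, assuming a fallback persona id (DEFAULT_PERSONA_ID is a placeholder, not part of this commit):

from alembic import op
import sqlalchemy as sa

DEFAULT_PERSONA_ID = 0  # assumption: id of a persona that is safe to fall back to


def downgrade() -> None:
    # Backfill NULL persona_id rows first so the NOT NULL constraint can be restored.
    op.execute(
        "UPDATE chat_session SET persona_id = "
        f"{DEFAULT_PERSONA_ID} WHERE persona_id IS NULL"
    )
    op.alter_column(
        "chat_session",
        "persona_id",
        existing_type=sa.INTEGER(),
        nullable=False,
    )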

View File

@@ -675,9 +675,11 @@ def stream_chat_message_objects(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
@@ -786,16 +788,18 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
),
)
logger.debug("Committing messages")

View File

@@ -5,6 +5,7 @@ from typing import cast
from typing import Optional
from typing import TypeVar
from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
@@ -153,15 +154,23 @@ def handle_regular_answer(
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
if new_message_request.persona_config:
raise HTTPException(
status_code=403,
detail="Slack bot does not support persona config",
)
elif new_message_request.persona_id:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context

View File

@@ -226,7 +226,7 @@ def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int,
    persona_id: int | None,  # Can be None if a temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,

View File

@@ -4,9 +4,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import Tool as ToolModel
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
    return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())


def fetch_existing_doc_sets(
    db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
    return list(
        db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
    )


def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
    return list(
        db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
    )


def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,

View File

@@ -866,7 +866,9 @@ class ChatSession(Base):
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -900,7 +902,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -909,7 +910,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"

View File

@@ -563,13 +563,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return prompts
return list(prompts)


def get_prompt_by_id(

View File

@@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
@@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -118,7 +119,17 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
persona = temporary_persona if temporary_persona else chat_session.persona
llm, fast_llm = get_llms_for_persona(persona=persona)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -153,11 +164,11 @@ def stream_answer_objects(
prompt_id=query_req.prompt_id, user=None, db_session=db_session
)
if prompt is None:
if not chat_session.persona.prompts:
if not persona.prompts:
raise RuntimeError(
"Persona does not have any prompts - this should never happen"
)
prompt = chat_session.persona.prompts[0]
prompt = persona.prompts[0]
# Create the first User query message
new_user_message = create_new_chat_message(
@@ -174,9 +185,7 @@ def stream_answer_objects(
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
),
max_tokens=max_document_tokens,
)
@@ -187,16 +196,16 @@ def stream_answer_objects(
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=chat_session.persona,
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
)
answer_config = AnswerStyleConfig(
@@ -209,13 +218,15 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
tools=[search_tool] if search_tool else [],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": rephrased_query},
force_use=True,
)
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
@@ -223,9 +234,7 @@ def stream_answer_objects(
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
@@ -261,6 +270,7 @@ def stream_answer_objects(
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)
yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID:
@@ -287,6 +297,7 @@ def stream_answer_objects(
relevance_summary=evaluation_response,
)
yield evaluation_response
else:
yield packet

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
@@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
@@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
    role: MessageType = MessageType.USER


class PromptConfig(BaseModel):
    name: str
    description: str = ""
    system_prompt: str
    task_prompt: str = ""
    include_citations: bool = True
    datetime_aware: bool = True


class DocumentSetConfig(BaseModel):
    id: int


class ToolConfig(BaseModel):
    id: int


class PersonaConfig(BaseModel):
    name: str
    description: str
    search_type: SearchType = SearchType.SEMANTIC
    num_chunks: float | None = None
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None
    prompts: list[PromptConfig] = Field(default_factory=list)
    prompt_ids: list[int] = Field(default_factory=list)
    document_set_ids: list[int] = Field(default_factory=list)
    tools: list[ToolConfig] = Field(default_factory=list)
    tool_ids: list[int] = Field(default_factory=list)
    custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)


class DirectQARequest(ChunkContext):
    persona_config: PersonaConfig | None = None
    persona_id: int | None = None
    messages: list[ThreadMessage]
    prompt_id: int | None
    persona_id: int
    prompt_id: int | None = None
multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
@@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
    # If True, skips generating an AI response to the search query
    skip_gen_ai_answer_generation: bool = False

    @model_validator(mode="after")
    def check_persona_fields(self) -> "DirectQARequest":
        if (self.persona_config is None) == (self.persona_id is None):
            raise ValueError("Exactly one of persona_config or persona_id must be set")
        return self

    @model_validator(mode="after")
    def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
        if self.chain_of_thought and self.prompt_id is not None:
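
The validator above makes persona_config and persona_id mutually exclusive: exactly one must be set, or construction raises a ValueError. A sketch of a one-shot request built against these models (field values are illustrative, and ThreadMessage's message field is assumed from context):

from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import PersonaConfig
from danswer.one_shot_answer.models import PromptConfig
from danswer.one_shot_answer.models import ThreadMessage

req = DirectQARequest(
    messages=[ThreadMessage(message="Summarize our refund policy.")],
    prompt_id=None,
    # Exactly one of persona_config / persona_id may be set:
    persona_config=PersonaConfig(
        name="temporary-support-persona",
        description="Answers strictly from the support document set",
        prompts=[
            PromptConfig(
                name="support-prompt",
                system_prompt="You are a concise support assistant.",
            )
        ],
        document_set_ids=[1],  # illustrative document set id
    ),
)

# Setting both persona_id and persona_config (or neither) fails validation:
# -> ValueError: Exactly one of persona_config or persona_id must be set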

View File

@@ -164,7 +164,7 @@ def get_chat_session(
chat_session_id=session_id,
description=chat_session.description,
persona_id=chat_session.persona_id,
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name if chat_session.persona else None,
current_alternate_model=chat_session.current_alternate_model,
messages=[
translate_db_message_to_chat_message_detail(

View File

@@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
class ChatSessionDetails(BaseModel):
id: int
name: str
persona_id: int
persona_id: int | None = None
time_created: str
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
@@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
class ChatSessionDetailResponse(BaseModel):
chat_session_id: int
description: str
persona_id: int
persona_name: str
persona_id: int | None = None
persona_name: str | None
messages: list[ChatMessageDetail]
time_created: datetime
shared_status: ChatSessionSharedStatus

View File

@@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -133,12 +134,23 @@ def get_answer_with_quote(
query = query_request.messages[0].message
logger.notice(f"Received query for one shot answer API with quotes: {query}")
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
if query_request.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session,
persona_config=query_request.persona_config,
user=user,
)
persona = new_persona
elif query_request.persona_id is not None:
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
raise KeyError("Must provide persona ID or Persona Config")
llm = get_main_llm_from_tuple(
get_default_llms() if not persona else get_llms_for_persona(persona)

View File

@@ -0,0 +1,83 @@
from typing import cast
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema


def create_temporary_persona(
    persona_config: PersonaConfig, db_session: Session, user: User | None = None
) -> Persona:
    """Create a temporary Persona object from the provided configuration."""
    if not is_user_admin(user):
        raise HTTPException(
            status_code=403,
            detail="User is not authorized to create a persona in one shot queries",
        )

    persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona
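
A minimal sketch of the new helper in use (values are illustrative; the calling user must pass is_user_admin or the helper raises a 403):

from danswer.one_shot_answer.models import PersonaConfig
from danswer.one_shot_answer.models import PromptConfig

persona = create_temporary_persona(
    persona_config=PersonaConfig(
        name="ad-hoc-persona",
        description="Scoped to a single request",
        prompts=[
            PromptConfig(
                name="terse-prompt",
                system_prompt="Answer in one short paragraph.",
            )
        ],
    ),
    db_session=db_session,  # an open SQLAlchemy Session, assumed in scope
    user=user,  # must be an admin
)
# The Persona is built in memory from the config rather than saved through the
# normal persona CRUD path; it can be passed to get_llms_for_persona() exactly
# as in the handlers above.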

View File

@@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
name: str | None
first_user_message: str
first_ai_message: str
persona_name: str
persona_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
@@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
user_email: str
name: str | None
messages: list[MessageSnapshot]
persona_name: str
persona_name: str | None
time_created: datetime
@@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str
persona_name: str | None
user_email: str
time_created: datetime
@@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str]:
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
@@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name
if chat_session.persona
else None,
time_created=chat_session.time_created,
feedback_type=feedback_type,
)
@@ -300,7 +302,7 @@ def snapshot_from_chat_session(
for message in messages
if message.message_type != MessageType.SYSTEM
],
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name if chat_session.persona else None,
time_created=chat_session.time_created,
)