Add ability to specify persona in API request (#2302)

* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
This commit is contained in:
pablodanswer 2024-09-16 14:31:01 -07:00 committed by GitHub
parent df464fc54b
commit 2dd3870504
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 283 additions and 64 deletions

View File

@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)

View File

@ -675,9 +675,11 @@ def stream_chat_message_objects(
db_session=db_session, db_session=db_session,
selected_search_docs=selected_db_search_docs, selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on # Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs dedupe_docs=(
if retrieval_options retrieval_options.dedupe_docs
else False, if retrieval_options
else False
),
) )
yield qa_docs_response yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID: elif packet.id == SECTION_RELEVANCE_LIST_ID:
@ -786,16 +788,18 @@ def stream_chat_message_objects(
if message_specific_citations if message_specific_citations
else None, else None,
error=None, error=None,
tool_calls=[ tool_calls=(
ToolCall( [
tool_id=tool_name_to_tool_id[tool_result.tool_name], ToolCall(
tool_name=tool_result.tool_name, tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_arguments=tool_result.tool_args, tool_name=tool_result.tool_name,
tool_result=tool_result.tool_result, tool_arguments=tool_result.tool_args,
) tool_result=tool_result.tool_result,
] )
if tool_result ]
else [], if tool_result
else []
),
) )
logger.debug("Committing messages") logger.debug("Committing messages")

View File

@ -5,6 +5,7 @@ from typing import cast
from typing import Optional from typing import Optional
from typing import TypeVar from typing import TypeVar
from fastapi import HTTPException
from retry import retry from retry import retry
from slack_sdk import WebClient from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock from slack_sdk.models.blocks import DividerBlock
@ -153,15 +154,23 @@ def handle_regular_answer(
with Session(get_sqlalchemy_engine()) as db_session: with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1: if len(new_message_request.messages) > 1:
persona = cast( if new_message_request.persona_config:
Persona, raise HTTPException(
fetch_persona_by_id( status_code=403,
db_session, detail="Slack bot does not support persona config",
new_message_request.persona_id, )
user=None,
get_editable=False, elif new_message_request.persona_id:
), persona = cast(
) Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
llm, _ = get_llms_for_persona(persona) llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context # In cases of threads, split the available tokens between docs and thread context

View File

@ -226,7 +226,7 @@ def create_chat_session(
db_session: Session, db_session: Session,
description: str, description: str,
user_id: UUID | None, user_id: UUID | None,
persona_id: int, persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None, llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None, prompt_override: PromptOverride | None = None,
one_shot: bool = False, one_shot: bool = False,

View File

@ -4,9 +4,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings from danswer.db.models import SearchSettings
from danswer.db.models import Tool as ToolModel
from danswer.db.models import User from danswer.db.models import User
from danswer.db.models import User__UserGroup from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all()) return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_doc_sets(
db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
return list(
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
)
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
return list(
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
)
def fetch_existing_llm_providers( def fetch_existing_llm_providers(
db_session: Session, db_session: Session,
user: User | None = None, user: User | None = None,

View File

@ -866,7 +866,9 @@ class ChatSession(Base):
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id")) persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text) description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed # One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False) one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
@ -900,7 +902,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column( prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True PydanticType(PromptOverride), nullable=True
) )
time_updated: Mapped[datetime.datetime] = mapped_column( time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), DateTime(timezone=True),
server_default=func.now(), server_default=func.now(),
@ -909,7 +910,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column( time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now() DateTime(timezone=True), server_default=func.now()
) )
user: Mapped[User] = relationship("User", back_populates="chat_sessions") user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship( folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions" "ChatFolder", back_populates="chat_sessions"

View File

@ -563,13 +563,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
) )
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]: def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users""" """Unsafe, can fetch prompts from all users"""
if not prompt_ids: if not prompt_ids:
return [] return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all() prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return prompts return list(prompts)
def get_prompt_by_id( def get_prompt_by_id(

View File

@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer from danswer.llm.answering.answer import Answer
@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallKickoff from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger() logger = setup_logger()
@ -118,7 +119,17 @@ def stream_answer_objects(
one_shot=True, one_shot=True,
danswerbot_flow=danswerbot_flow, danswerbot_flow=danswerbot_flow,
) )
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
persona = temporary_persona if temporary_persona else chat_session.persona
llm, fast_llm = get_llms_for_persona(persona=persona)
llm_tokenizer = get_tokenizer( llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name, model_name=llm.config.model_name,
@ -153,11 +164,11 @@ def stream_answer_objects(
prompt_id=query_req.prompt_id, user=None, db_session=db_session prompt_id=query_req.prompt_id, user=None, db_session=db_session
) )
if prompt is None: if prompt is None:
if not chat_session.persona.prompts: if not persona.prompts:
raise RuntimeError( raise RuntimeError(
"Persona does not have any prompts - this should never happen" "Persona does not have any prompts - this should never happen"
) )
prompt = chat_session.persona.prompts[0] prompt = persona.prompts[0]
# Create the first User query message # Create the first User query message
new_user_message = create_new_chat_message( new_user_message = create_new_chat_message(
@ -174,9 +185,7 @@ def stream_answer_objects(
prompt_config = PromptConfig.from_model(prompt) prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig( document_pruning_config = DocumentPruningConfig(
max_chunks=int( max_chunks=int(
chat_session.persona.num_chunks persona.num_chunks if persona.num_chunks is not None else default_num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
), ),
max_tokens=max_document_tokens, max_tokens=max_document_tokens,
) )
@ -187,16 +196,16 @@ def stream_answer_objects(
evaluation_type=LLMEvaluationType.SKIP evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type, else query_req.evaluation_type,
persona=chat_session.persona, persona=persona,
retrieval_options=query_req.retrieval_options, retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config, prompt_config=prompt_config,
llm=llm, llm=llm,
fast_llm=fast_llm, fast_llm=fast_llm,
pruning_config=document_pruning_config, pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above, chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below, chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc, full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
) )
answer_config = AnswerStyleConfig( answer_config = AnswerStyleConfig(
@ -209,13 +218,15 @@ def stream_answer_objects(
question=query_msg.message, question=query_msg.message,
answer_style_config=answer_config, answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt), prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)), llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
single_message_history=history_str, single_message_history=history_str,
tools=[search_tool], tools=[search_tool] if search_tool else [],
force_use_tool=ForceUseTool( force_use_tool=(
force_use=True, ForceUseTool(
tool_name=search_tool.name, tool_name=search_tool.name,
args={"query": rephrased_query}, args={"query": rephrased_query},
force_use=True,
)
), ),
# for now, don't use tool calling for this flow, as we haven't # for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet # tested quotes with tool calling too much yet
@ -223,9 +234,7 @@ def stream_answer_objects(
return_contexts=query_req.return_contexts, return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation, skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
) )
# won't be any ImageGenerationDisplay responses since that tool is never passed in # won't be any ImageGenerationDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these # for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse): if isinstance(packet, ToolResponse):
@ -261,6 +270,7 @@ def stream_answer_objects(
applied_time_cutoff=search_response_summary.final_filters.time_cutoff, applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier, recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
) )
yield initial_response yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID: elif packet.id == SEARCH_DOC_CONTENT_ID:
@ -287,6 +297,7 @@ def stream_answer_objects(
relevance_summary=evaluation_response, relevance_summary=evaluation_response,
) )
yield evaluation_response yield evaluation_response
else: else:
yield packet yield packet

View File

@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from pydantic import Field from pydantic import Field
from pydantic import model_validator from pydantic import model_validator
@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails from danswer.search.models import RetrievalDetails
@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
role: MessageType = MessageType.USER role: MessageType = MessageType.USER
class PromptConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class DocumentSetConfig(BaseModel):
id: int
class ToolConfig(BaseModel):
id: int
class PersonaConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
class DirectQARequest(ChunkContext): class DirectQARequest(ChunkContext):
persona_config: PersonaConfig | None = None
persona_id: int | None = None
messages: list[ThreadMessage] messages: list[ThreadMessage]
prompt_id: int | None prompt_id: int | None = None
persona_id: int
multilingual_query_expansion: list[str] | None = None multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None rerank_settings: RerankingDetails | None = None
@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
# If True, skips generative an AI response to the search query # If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "DirectQARequest":
if (self.persona_config is None) == (self.persona_id is None):
raise ValueError("Exactly one of persona_config or persona_id must be set")
return self
@model_validator(mode="after") @model_validator(mode="after")
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest": def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
if self.chain_of_thought and self.prompt_id is not None: if self.chain_of_thought and self.prompt_id is not None:

View File

@ -164,7 +164,7 @@ def get_chat_session(
chat_session_id=session_id, chat_session_id=session_id,
description=chat_session.description, description=chat_session.description,
persona_id=chat_session.persona_id, persona_id=chat_session.persona_id,
persona_name=chat_session.persona.name, persona_name=chat_session.persona.name if chat_session.persona else None,
current_alternate_model=chat_session.current_alternate_model, current_alternate_model=chat_session.current_alternate_model,
messages=[ messages=[
translate_db_message_to_chat_message_detail( translate_db_message_to_chat_message_detail(

View File

@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
class ChatSessionDetails(BaseModel): class ChatSessionDetails(BaseModel):
id: int id: int
name: str name: str
persona_id: int persona_id: int | None = None
time_created: str time_created: str
shared_status: ChatSessionSharedStatus shared_status: ChatSessionSharedStatus
folder_id: int | None = None folder_id: int | None = None
@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
class ChatSessionDetailResponse(BaseModel): class ChatSessionDetailResponse(BaseModel):
chat_session_id: int chat_session_id: int
description: str description: str
persona_id: int persona_id: int | None = None
persona_name: str persona_name: str | None
messages: list[ChatMessageDetail] messages: list[ChatMessageDetail]
time_created: datetime time_created: datetime
shared_status: ChatSessionSharedStatus shared_status: ChatSessionSharedStatus

View File

@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger() logger = setup_logger()
@ -133,12 +134,23 @@ def get_answer_with_quote(
query = query_request.messages[0].message query = query_request.messages[0].message
logger.notice(f"Received query for one shot answer API with quotes: {query}") logger.notice(f"Received query for one shot answer API with quotes: {query}")
persona = get_persona_by_id( if query_request.persona_config is not None:
persona_id=query_request.persona_id, new_persona = create_temporary_persona(
user=user, db_session=db_session,
db_session=db_session, persona_config=query_request.persona_config,
is_for_edit=False, user=user,
) )
persona = new_persona
elif query_request.persona_id is not None:
persona = get_persona_by_id(
persona_id=query_request.persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
raise KeyError("Must provide persona ID or Persona Config")
llm = get_main_llm_from_tuple( llm = get_main_llm_from_tuple(
get_default_llms() if not persona else get_llms_for_persona(persona) get_default_llms() if not persona else get_llms_for_persona(persona)

View File

@ -0,0 +1,83 @@
from typing import cast
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
def create_temporary_persona(
persona_config: PersonaConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
name: str | None name: str | None
first_user_message: str first_user_message: str
first_ai_message: str first_ai_message: str
persona_name: str persona_name: str | None
time_created: datetime time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None feedback_type: QAFeedbackType | Literal["mixed"] | None
@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
user_email: str user_email: str
name: str | None name: str | None
messages: list[MessageSnapshot] messages: list[MessageSnapshot]
persona_name: str persona_name: str | None
time_created: datetime time_created: datetime
@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
retrieved_documents: list[AbridgedSearchDoc] retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None feedback_type: QAFeedbackType | None
feedback_text: str | None feedback_text: str | None
persona_name: str persona_name: str | None
user_email: str user_email: str
time_created: datetime time_created: datetime
@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
for ind, (user_message, ai_message) in enumerate(message_pairs) for ind, (user_message, ai_message) in enumerate(message_pairs)
] ]
def to_json(self) -> dict[str, str]: def to_json(self) -> dict[str, str | None]:
return { return {
"chat_session_id": str(self.chat_session_id), "chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num), "message_pair_num": str(self.message_pair_num),
@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
name=chat_session.description, name=chat_session.description,
first_user_message=first_user_message, first_user_message=first_user_message,
first_ai_message=first_ai_message, first_ai_message=first_ai_message,
persona_name=chat_session.persona.name, persona_name=chat_session.persona.name
if chat_session.persona
else None,
time_created=chat_session.time_created, time_created=chat_session.time_created,
feedback_type=feedback_type, feedback_type=feedback_type,
) )
@ -300,7 +302,7 @@ def snapshot_from_chat_session(
for message in messages for message in messages
if message.message_type != MessageType.SYSTEM if message.message_type != MessageType.SYSTEM
], ],
persona_name=chat_session.persona.name, persona_name=chat_session.persona.name if chat_session.persona else None,
time_created=chat_session.time_created, time_created=chat_session.time_created,
) )