mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-02 11:09:20 +02:00
Add ability to specify persona in API request (#2302)
* persona * all prepared excluding configuration * more sensical model structure * update tstream * type updates * rm * quick and simple updates * minor updates * te * ensure typing + naming * remove old todo + rebase update * remove unnecessary check
This commit is contained in:
parent
df464fc54b
commit
2dd3870504
@ -0,0 +1,31 @@
|
||||
"""add nullable to persona id in Chat Session
|
||||
|
||||
Revision ID: c99d76fcd298
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-07-09 19:27:01.579697
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c99d76fcd298"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
@ -675,9 +675,11 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
dedupe_docs=(
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
),
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
@ -786,16 +788,18 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else [],
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
@ -5,6 +5,7 @@ from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
@ -153,15 +154,23 @@ def handle_regular_answer(
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
if new_message_request.persona_config:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Slack bot does not support persona config",
|
||||
)
|
||||
|
||||
elif new_message_request.persona_id:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
|
@ -226,7 +226,7 @@ def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
|
@ -4,9 +4,11 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import Tool as ToolModel
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_doc_sets(
|
||||
db_session: Session, doc_ids: list[int]
|
||||
) -> list[DocumentSet]:
|
||||
return list(
|
||||
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
|
||||
return list(
|
||||
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
|
@ -866,7 +866,9 @@ class ChatSession(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@ -900,7 +902,6 @@ class ChatSession(Base):
|
||||
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
||||
PydanticType(PromptOverride), nullable=True
|
||||
)
|
||||
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@ -909,7 +910,6 @@ class ChatSession(Base):
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
folder: Mapped["ChatFolder"] = relationship(
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
|
@ -563,13 +563,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return prompts
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_id(
|
||||
|
@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -118,7 +119,17 @@ def stream_answer_objects(
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
if query_req.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session, persona_config=query_req.persona_config, user=user
|
||||
)
|
||||
temporary_persona = new_persona
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
llm, fast_llm = get_llms_for_persona(persona=persona)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@ -153,11 +164,11 @@ def stream_answer_objects(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not chat_session.persona.prompts:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = chat_session.persona.prompts[0]
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
@ -174,9 +185,7 @@ def stream_answer_objects(
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
chat_session.persona.num_chunks
|
||||
if chat_session.persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
@ -187,16 +196,16 @@ def stream_answer_objects(
|
||||
evaluation_type=LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type,
|
||||
persona=chat_session.persona,
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
@ -209,13 +218,15 @@ def stream_answer_objects(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
@ -223,9 +234,7 @@ def stream_answer_objects(
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
@ -261,6 +270,7 @@ def stream_answer_objects(
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
@ -287,6 +297,7 @@ def stream_answer_objects(
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class DocumentSetConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None
|
||||
persona_id: int
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
|
@ -164,7 +164,7 @@ def get_chat_session(
|
||||
chat_session_id=session_id,
|
||||
description=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||
current_alternate_model=chat_session.current_alternate_model,
|
||||
messages=[
|
||||
translate_db_message_to_chat_message_detail(
|
||||
|
@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
|
||||
class ChatSessionDetails(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
persona_id: int
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: int
|
||||
description: str
|
||||
persona_id: int
|
||||
persona_name: str
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
messages: list[ChatMessageDetail]
|
||||
time_created: datetime
|
||||
shared_status: ChatSessionSharedStatus
|
||||
|
@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
||||
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -133,12 +134,23 @@ def get_answer_with_quote(
|
||||
query = query_request.messages[0].message
|
||||
logger.notice(f"Received query for one shot answer API with quotes: {query}")
|
||||
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
if query_request.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=query_request.persona_config,
|
||||
user=user,
|
||||
)
|
||||
persona = new_persona
|
||||
|
||||
elif query_request.persona_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=query_request.persona_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
raise KeyError("Must provide persona ID or Persona Config")
|
||||
|
||||
llm = get_main_llm_from_tuple(
|
||||
get_default_llms() if not persona else get_llms_for_persona(persona)
|
||||
|
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
@ -0,0 +1,83 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.one_shot_answer.models import PersonaConfig
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
|
||||
name: str | None
|
||||
first_user_message: str
|
||||
first_ai_message: str
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
time_created: datetime
|
||||
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
||||
|
||||
@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
|
||||
user_email: str
|
||||
name: str | None
|
||||
messages: list[MessageSnapshot]
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
time_created: datetime
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
retrieved_documents: list[AbridgedSearchDoc]
|
||||
feedback_type: QAFeedbackType | None
|
||||
feedback_text: str | None
|
||||
persona_name: str
|
||||
persona_name: str | None
|
||||
user_email: str
|
||||
time_created: datetime
|
||||
|
||||
@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||
]
|
||||
|
||||
def to_json(self) -> dict[str, str]:
|
||||
def to_json(self) -> dict[str, str | None]:
|
||||
return {
|
||||
"chat_session_id": str(self.chat_session_id),
|
||||
"message_pair_num": str(self.message_pair_num),
|
||||
@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
|
||||
name=chat_session.description,
|
||||
first_user_message=first_user_message,
|
||||
first_ai_message=first_ai_message,
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name
|
||||
if chat_session.persona
|
||||
else None,
|
||||
time_created=chat_session.time_created,
|
||||
feedback_type=feedback_type,
|
||||
)
|
||||
@ -300,7 +302,7 @@ def snapshot_from_chat_session(
|
||||
for message in messages
|
||||
if message.message_type != MessageType.SYSTEM
|
||||
],
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||
time_created=chat_session.time_created,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user