mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-06 04:59:24 +02:00
Add ability to specify persona in API request (#2302)
* persona * all prepared excluding configuration * more sensical model structure * update tstream * type updates * rm * quick and simple updates * minor updates * te * ensure typing + naming * remove old todo + rebase update * remove unnecessary check
This commit is contained in:
parent
df464fc54b
commit
2dd3870504
@ -0,0 +1,31 @@
|
|||||||
|
"""add nullable to persona id in Chat Session
|
||||||
|
|
||||||
|
Revision ID: c99d76fcd298
|
||||||
|
Revises: 5c7fdadae813
|
||||||
|
Create Date: 2024-07-09 19:27:01.579697
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "c99d76fcd298"
|
||||||
|
down_revision = "5c7fdadae813"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Allow NULL in ``chat_session.persona_id``."""
    op.alter_column(
        "chat_session",
        "persona_id",
        existing_type=sa.INTEGER(),
        nullable=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Restore the NOT NULL constraint on ``chat_session.persona_id``."""
    # NOTE(review): once rows with a NULL persona_id exist, the database will
    # presumably reject this ALTER; such rows would need to be backfilled or
    # removed first - confirm before relying on this downgrade.
    persona_id_type = sa.INTEGER()
    op.alter_column(
        "chat_session", "persona_id", existing_type=persona_id_type, nullable=False
    )
|
@ -675,9 +675,11 @@ def stream_chat_message_objects(
|
|||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
selected_search_docs=selected_db_search_docs,
|
selected_search_docs=selected_db_search_docs,
|
||||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||||
dedupe_docs=retrieval_options.dedupe_docs
|
dedupe_docs=(
|
||||||
if retrieval_options
|
retrieval_options.dedupe_docs
|
||||||
else False,
|
if retrieval_options
|
||||||
|
else False
|
||||||
|
),
|
||||||
)
|
)
|
||||||
yield qa_docs_response
|
yield qa_docs_response
|
||||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||||
@ -786,16 +788,18 @@ def stream_chat_message_objects(
|
|||||||
if message_specific_citations
|
if message_specific_citations
|
||||||
else None,
|
else None,
|
||||||
error=None,
|
error=None,
|
||||||
tool_calls=[
|
tool_calls=(
|
||||||
ToolCall(
|
[
|
||||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
ToolCall(
|
||||||
tool_name=tool_result.tool_name,
|
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||||
tool_arguments=tool_result.tool_args,
|
tool_name=tool_result.tool_name,
|
||||||
tool_result=tool_result.tool_result,
|
tool_arguments=tool_result.tool_args,
|
||||||
)
|
tool_result=tool_result.tool_result,
|
||||||
]
|
)
|
||||||
if tool_result
|
]
|
||||||
else [],
|
if tool_result
|
||||||
|
else []
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Committing messages")
|
logger.debug("Committing messages")
|
||||||
|
@ -5,6 +5,7 @@ from typing import cast
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
from retry import retry
|
from retry import retry
|
||||||
from slack_sdk import WebClient
|
from slack_sdk import WebClient
|
||||||
from slack_sdk.models.blocks import DividerBlock
|
from slack_sdk.models.blocks import DividerBlock
|
||||||
@ -153,15 +154,23 @@ def handle_regular_answer(
|
|||||||
|
|
||||||
with Session(get_sqlalchemy_engine()) as db_session:
|
with Session(get_sqlalchemy_engine()) as db_session:
|
||||||
if len(new_message_request.messages) > 1:
|
if len(new_message_request.messages) > 1:
|
||||||
persona = cast(
|
if new_message_request.persona_config:
|
||||||
Persona,
|
raise HTTPException(
|
||||||
fetch_persona_by_id(
|
status_code=403,
|
||||||
db_session,
|
detail="Slack bot does not support persona config",
|
||||||
new_message_request.persona_id,
|
)
|
||||||
user=None,
|
|
||||||
get_editable=False,
|
elif new_message_request.persona_id:
|
||||||
),
|
persona = cast(
|
||||||
)
|
Persona,
|
||||||
|
fetch_persona_by_id(
|
||||||
|
db_session,
|
||||||
|
new_message_request.persona_id,
|
||||||
|
user=None,
|
||||||
|
get_editable=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
llm, _ = get_llms_for_persona(persona)
|
llm, _ = get_llms_for_persona(persona)
|
||||||
|
|
||||||
# In cases of threads, split the available tokens between docs and thread context
|
# In cases of threads, split the available tokens between docs and thread context
|
||||||
|
@ -226,7 +226,7 @@ def create_chat_session(
|
|||||||
db_session: Session,
|
db_session: Session,
|
||||||
description: str,
|
description: str,
|
||||||
user_id: UUID | None,
|
user_id: UUID | None,
|
||||||
persona_id: int,
|
persona_id: int | None, # Can be none if temporary persona is used
|
||||||
llm_override: LLMOverride | None = None,
|
llm_override: LLMOverride | None = None,
|
||||||
prompt_override: PromptOverride | None = None,
|
prompt_override: PromptOverride | None = None,
|
||||||
one_shot: bool = False,
|
one_shot: bool = False,
|
||||||
|
@ -4,9 +4,11 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||||
|
from danswer.db.models import DocumentSet
|
||||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||||
from danswer.db.models import LLMProvider__UserGroup
|
from danswer.db.models import LLMProvider__UserGroup
|
||||||
from danswer.db.models import SearchSettings
|
from danswer.db.models import SearchSettings
|
||||||
|
from danswer.db.models import Tool as ToolModel
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.db.models import User__UserGroup
|
from danswer.db.models import User__UserGroup
|
||||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||||
@ -103,6 +105,20 @@ def fetch_existing_embedding_providers(
|
|||||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_existing_doc_sets(
    db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
    """Return the document sets whose ID appears in ``doc_ids``.

    IDs without a matching row are silently skipped, so the result may be
    shorter than ``doc_ids``. The query requests no particular ordering.

    Args:
        db_session: Session used to run the lookup.
        doc_ids: IDs to match against ``DocumentSet.id`` (despite the
            parameter name, these are document *set* IDs).

    Returns:
        The matching ``DocumentSet`` rows; an empty list when ``doc_ids`` is
        empty.
    """
    # Nothing can match an empty ID list, so skip the database round trip
    # (same guard as get_prompts_by_ids uses for prompt IDs).
    if not doc_ids:
        return []

    query = select(DocumentSet).where(DocumentSet.id.in_(doc_ids))
    return list(db_session.scalars(query).all())
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
    """Return the stored tools whose ID appears in ``tool_ids``.

    IDs without a matching row are silently skipped, so the result may be
    shorter than ``tool_ids``. The query requests no particular ordering.

    Args:
        db_session: Session used to run the lookup.
        tool_ids: IDs to match against ``ToolModel.id``.

    Returns:
        The matching ``ToolModel`` rows; an empty list when ``tool_ids`` is
        empty.
    """
    # Nothing can match an empty ID list, so skip the database round trip
    # (same guard as get_prompts_by_ids uses for prompt IDs).
    if not tool_ids:
        return []

    query = select(ToolModel).where(ToolModel.id.in_(tool_ids))
    return list(db_session.scalars(query).all())
|
||||||
|
|
||||||
|
|
||||||
def fetch_existing_llm_providers(
|
def fetch_existing_llm_providers(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
user: User | None = None,
|
user: User | None = None,
|
||||||
|
@ -866,7 +866,9 @@ class ChatSession(Base):
|
|||||||
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
|
persona_id: Mapped[int | None] = mapped_column(
|
||||||
|
ForeignKey("persona.id"), nullable=True
|
||||||
|
)
|
||||||
description: Mapped[str] = mapped_column(Text)
|
description: Mapped[str] = mapped_column(Text)
|
||||||
# One-shot direct answering, currently the two types of chats are not mixed
|
# One-shot direct answering, currently the two types of chats are not mixed
|
||||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
@ -900,7 +902,6 @@ class ChatSession(Base):
|
|||||||
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
||||||
PydanticType(PromptOverride), nullable=True
|
PydanticType(PromptOverride), nullable=True
|
||||||
)
|
)
|
||||||
|
|
||||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||||
DateTime(timezone=True),
|
DateTime(timezone=True),
|
||||||
server_default=func.now(),
|
server_default=func.now(),
|
||||||
@ -909,7 +910,6 @@ class ChatSession(Base):
|
|||||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||||
DateTime(timezone=True), server_default=func.now()
|
DateTime(timezone=True), server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||||
folder: Mapped["ChatFolder"] = relationship(
|
folder: Mapped["ChatFolder"] = relationship(
|
||||||
"ChatFolder", back_populates="chat_sessions"
|
"ChatFolder", back_populates="chat_sessions"
|
||||||
|
@ -563,13 +563,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
|
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||||
"""Unsafe, can fetch prompts from all users"""
|
"""Unsafe, can fetch prompts from all users"""
|
||||||
if not prompt_ids:
|
if not prompt_ids:
|
||||||
return []
|
return []
|
||||||
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
|
prompts = db_session.scalars(
|
||||||
|
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||||
|
).all()
|
||||||
|
|
||||||
return prompts
|
return list(prompts)
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_by_id(
|
def get_prompt_by_id(
|
||||||
|
@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
|
|||||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||||
from danswer.db.engine import get_session_context_manager
|
from danswer.db.engine import get_session_context_manager
|
||||||
|
from danswer.db.models import Persona
|
||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.db.persona import get_prompt_by_id
|
from danswer.db.persona import get_prompt_by_id
|
||||||
from danswer.llm.answering.answer import Answer
|
from danswer.llm.answering.answer import Answer
|
||||||
@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
|
|||||||
from danswer.tools.tool_runner import ToolCallKickoff
|
from danswer.tools.tool_runner import ToolCallKickoff
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from danswer.utils.timing import log_generator_function_time
|
from danswer.utils.timing import log_generator_function_time
|
||||||
|
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
@ -118,7 +119,17 @@ def stream_answer_objects(
|
|||||||
one_shot=True,
|
one_shot=True,
|
||||||
danswerbot_flow=danswerbot_flow,
|
danswerbot_flow=danswerbot_flow,
|
||||||
)
|
)
|
||||||
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
|
|
||||||
|
temporary_persona: Persona | None = None
|
||||||
|
if query_req.persona_config is not None:
|
||||||
|
new_persona = create_temporary_persona(
|
||||||
|
db_session=db_session, persona_config=query_req.persona_config, user=user
|
||||||
|
)
|
||||||
|
temporary_persona = new_persona
|
||||||
|
|
||||||
|
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||||
|
|
||||||
|
llm, fast_llm = get_llms_for_persona(persona=persona)
|
||||||
|
|
||||||
llm_tokenizer = get_tokenizer(
|
llm_tokenizer = get_tokenizer(
|
||||||
model_name=llm.config.model_name,
|
model_name=llm.config.model_name,
|
||||||
@ -153,11 +164,11 @@ def stream_answer_objects(
|
|||||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||||
)
|
)
|
||||||
if prompt is None:
|
if prompt is None:
|
||||||
if not chat_session.persona.prompts:
|
if not persona.prompts:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Persona does not have any prompts - this should never happen"
|
"Persona does not have any prompts - this should never happen"
|
||||||
)
|
)
|
||||||
prompt = chat_session.persona.prompts[0]
|
prompt = persona.prompts[0]
|
||||||
|
|
||||||
# Create the first User query message
|
# Create the first User query message
|
||||||
new_user_message = create_new_chat_message(
|
new_user_message = create_new_chat_message(
|
||||||
@ -174,9 +185,7 @@ def stream_answer_objects(
|
|||||||
prompt_config = PromptConfig.from_model(prompt)
|
prompt_config = PromptConfig.from_model(prompt)
|
||||||
document_pruning_config = DocumentPruningConfig(
|
document_pruning_config = DocumentPruningConfig(
|
||||||
max_chunks=int(
|
max_chunks=int(
|
||||||
chat_session.persona.num_chunks
|
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||||
if chat_session.persona.num_chunks is not None
|
|
||||||
else default_num_chunks
|
|
||||||
),
|
),
|
||||||
max_tokens=max_document_tokens,
|
max_tokens=max_document_tokens,
|
||||||
)
|
)
|
||||||
@ -187,16 +196,16 @@ def stream_answer_objects(
|
|||||||
evaluation_type=LLMEvaluationType.SKIP
|
evaluation_type=LLMEvaluationType.SKIP
|
||||||
if DISABLE_LLM_DOC_RELEVANCE
|
if DISABLE_LLM_DOC_RELEVANCE
|
||||||
else query_req.evaluation_type,
|
else query_req.evaluation_type,
|
||||||
persona=chat_session.persona,
|
persona=persona,
|
||||||
retrieval_options=query_req.retrieval_options,
|
retrieval_options=query_req.retrieval_options,
|
||||||
prompt_config=prompt_config,
|
prompt_config=prompt_config,
|
||||||
llm=llm,
|
llm=llm,
|
||||||
fast_llm=fast_llm,
|
fast_llm=fast_llm,
|
||||||
pruning_config=document_pruning_config,
|
pruning_config=document_pruning_config,
|
||||||
|
bypass_acl=bypass_acl,
|
||||||
chunks_above=query_req.chunks_above,
|
chunks_above=query_req.chunks_above,
|
||||||
chunks_below=query_req.chunks_below,
|
chunks_below=query_req.chunks_below,
|
||||||
full_doc=query_req.full_doc,
|
full_doc=query_req.full_doc,
|
||||||
bypass_acl=bypass_acl,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
answer_config = AnswerStyleConfig(
|
answer_config = AnswerStyleConfig(
|
||||||
@ -209,13 +218,15 @@ def stream_answer_objects(
|
|||||||
question=query_msg.message,
|
question=query_msg.message,
|
||||||
answer_style_config=answer_config,
|
answer_style_config=answer_config,
|
||||||
prompt_config=PromptConfig.from_model(prompt),
|
prompt_config=PromptConfig.from_model(prompt),
|
||||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
|
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
|
||||||
single_message_history=history_str,
|
single_message_history=history_str,
|
||||||
tools=[search_tool],
|
tools=[search_tool] if search_tool else [],
|
||||||
force_use_tool=ForceUseTool(
|
force_use_tool=(
|
||||||
force_use=True,
|
ForceUseTool(
|
||||||
tool_name=search_tool.name,
|
tool_name=search_tool.name,
|
||||||
args={"query": rephrased_query},
|
args={"query": rephrased_query},
|
||||||
|
force_use=True,
|
||||||
|
)
|
||||||
),
|
),
|
||||||
# for now, don't use tool calling for this flow, as we haven't
|
# for now, don't use tool calling for this flow, as we haven't
|
||||||
# tested quotes with tool calling too much yet
|
# tested quotes with tool calling too much yet
|
||||||
@ -223,9 +234,7 @@ def stream_answer_objects(
|
|||||||
return_contexts=query_req.return_contexts,
|
return_contexts=query_req.return_contexts,
|
||||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||||
)
|
)
|
||||||
|
|
||||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||||
|
|
||||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||||
# for one-shot flow, don't currently do anything with these
|
# for one-shot flow, don't currently do anything with these
|
||||||
if isinstance(packet, ToolResponse):
|
if isinstance(packet, ToolResponse):
|
||||||
@ -261,6 +270,7 @@ def stream_answer_objects(
|
|||||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield initial_response
|
yield initial_response
|
||||||
|
|
||||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||||
@ -287,6 +297,7 @@ def stream_answer_objects(
|
|||||||
relevance_summary=evaluation_response,
|
relevance_summary=evaluation_response,
|
||||||
)
|
)
|
||||||
yield evaluation_response
|
yield evaluation_response
|
||||||
|
|
||||||
else:
|
else:
|
||||||
yield packet
|
yield packet
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
|
|||||||
from danswer.chat.models import QADocsResponse
|
from danswer.chat.models import QADocsResponse
|
||||||
from danswer.configs.constants import MessageType
|
from danswer.configs.constants import MessageType
|
||||||
from danswer.search.enums import LLMEvaluationType
|
from danswer.search.enums import LLMEvaluationType
|
||||||
|
from danswer.search.enums import RecencyBiasSetting
|
||||||
|
from danswer.search.enums import SearchType
|
||||||
from danswer.search.models import ChunkContext
|
from danswer.search.models import ChunkContext
|
||||||
from danswer.search.models import RerankingDetails
|
from danswer.search.models import RerankingDetails
|
||||||
from danswer.search.models import RetrievalDetails
|
from danswer.search.models import RetrievalDetails
|
||||||
@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
|
|||||||
role: MessageType = MessageType.USER
|
role: MessageType = MessageType.USER
|
||||||
|
|
||||||
|
|
||||||
|
class PromptConfig(BaseModel):
    # Inline prompt definition that can be supplied inside a PersonaConfig.
    # Only `name` and `system_prompt` are required; every other field has a
    # default. (Documented with comments rather than a docstring so the
    # generated schema description is left untouched.)

    name: str
    description: str = ""

    system_prompt: str
    task_prompt: str = ""

    include_citations: bool = True
    datetime_aware: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentSetConfig(BaseModel):
    # Reference to a document set by its integer ID.
    id: int
|
||||||
|
|
||||||
|
|
||||||
|
class ToolConfig(BaseModel):
    # Reference to a tool by its integer ID.
    id: int
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaConfig(BaseModel):
    # Request-supplied description of a persona, used in place of a stored
    # persona referenced by ID. Only `name` and `description` are required.
    # (Documented with comments rather than a docstring so the generated
    # schema description is left untouched.)

    name: str
    description: str

    # Search / retrieval behaviour.
    search_type: SearchType = SearchType.SEMANTIC
    num_chunks: float | None = None  # None means "no explicit chunk count"
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO

    # Optional LLM provider / model version overrides.
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None

    # Prompts can be given inline (`prompts`) or as IDs (`prompt_ids`).
    prompts: list[PromptConfig] = Field(default_factory=list)
    prompt_ids: list[int] = Field(default_factory=list)

    # IDs of document sets to attach to the persona.
    document_set_ids: list[int] = Field(default_factory=list)

    # Tools can be referenced as ToolConfig objects (`tools`), as bare IDs
    # (`tool_ids`), or described by OpenAPI schema dicts
    # (`custom_tools_openapi`).
    tools: list[ToolConfig] = Field(default_factory=list)
    tool_ids: list[int] = Field(default_factory=list)
    custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class DirectQARequest(ChunkContext):
|
class DirectQARequest(ChunkContext):
|
||||||
|
persona_config: PersonaConfig | None = None
|
||||||
|
persona_id: int | None = None
|
||||||
|
|
||||||
messages: list[ThreadMessage]
|
messages: list[ThreadMessage]
|
||||||
prompt_id: int | None
|
prompt_id: int | None = None
|
||||||
persona_id: int
|
|
||||||
multilingual_query_expansion: list[str] | None = None
|
multilingual_query_expansion: list[str] | None = None
|
||||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||||
rerank_settings: RerankingDetails | None = None
|
rerank_settings: RerankingDetails | None = None
|
||||||
@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
|
|||||||
# If True, skips generative an AI response to the search query
|
# If True, skips generative an AI response to the search query
|
||||||
skip_gen_ai_answer_generation: bool = False
|
skip_gen_ai_answer_generation: bool = False
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_persona_fields(self) -> "DirectQARequest":
|
||||||
|
if (self.persona_config is None) == (self.persona_id is None):
|
||||||
|
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||||
|
return self
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||||
if self.chain_of_thought and self.prompt_id is not None:
|
if self.chain_of_thought and self.prompt_id is not None:
|
||||||
|
@ -164,7 +164,7 @@ def get_chat_session(
|
|||||||
chat_session_id=session_id,
|
chat_session_id=session_id,
|
||||||
description=chat_session.description,
|
description=chat_session.description,
|
||||||
persona_id=chat_session.persona_id,
|
persona_id=chat_session.persona_id,
|
||||||
persona_name=chat_session.persona.name,
|
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||||
current_alternate_model=chat_session.current_alternate_model,
|
current_alternate_model=chat_session.current_alternate_model,
|
||||||
messages=[
|
messages=[
|
||||||
translate_db_message_to_chat_message_detail(
|
translate_db_message_to_chat_message_detail(
|
||||||
|
@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
|
|||||||
class ChatSessionDetails(BaseModel):
|
class ChatSessionDetails(BaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
persona_id: int
|
persona_id: int | None = None
|
||||||
time_created: str
|
time_created: str
|
||||||
shared_status: ChatSessionSharedStatus
|
shared_status: ChatSessionSharedStatus
|
||||||
folder_id: int | None = None
|
folder_id: int | None = None
|
||||||
@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
|
|||||||
class ChatSessionDetailResponse(BaseModel):
|
class ChatSessionDetailResponse(BaseModel):
|
||||||
chat_session_id: int
|
chat_session_id: int
|
||||||
description: str
|
description: str
|
||||||
persona_id: int
|
persona_id: int | None = None
|
||||||
persona_name: str
|
persona_name: str | None
|
||||||
messages: list[ChatMessageDetail]
|
messages: list[ChatMessageDetail]
|
||||||
time_created: datetime
|
time_created: datetime
|
||||||
shared_status: ChatSessionSharedStatus
|
shared_status: ChatSessionSharedStatus
|
||||||
|
@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import (
|
|||||||
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
from ee.danswer.server.query_and_chat.models import DocumentSearchRequest
|
||||||
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
from ee.danswer.server.query_and_chat.models import StandardAnswerRequest
|
||||||
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
from ee.danswer.server.query_and_chat.models import StandardAnswerResponse
|
||||||
|
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -133,12 +134,23 @@ def get_answer_with_quote(
|
|||||||
query = query_request.messages[0].message
|
query = query_request.messages[0].message
|
||||||
logger.notice(f"Received query for one shot answer API with quotes: {query}")
|
logger.notice(f"Received query for one shot answer API with quotes: {query}")
|
||||||
|
|
||||||
persona = get_persona_by_id(
|
if query_request.persona_config is not None:
|
||||||
persona_id=query_request.persona_id,
|
new_persona = create_temporary_persona(
|
||||||
user=user,
|
db_session=db_session,
|
||||||
db_session=db_session,
|
persona_config=query_request.persona_config,
|
||||||
is_for_edit=False,
|
user=user,
|
||||||
)
|
)
|
||||||
|
persona = new_persona
|
||||||
|
|
||||||
|
elif query_request.persona_id is not None:
|
||||||
|
persona = get_persona_by_id(
|
||||||
|
persona_id=query_request.persona_id,
|
||||||
|
user=user,
|
||||||
|
db_session=db_session,
|
||||||
|
is_for_edit=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise KeyError("Must provide persona ID or Persona Config")
|
||||||
|
|
||||||
llm = get_main_llm_from_tuple(
|
llm = get_main_llm_from_tuple(
|
||||||
get_default_llms() if not persona else get_llms_for_persona(persona)
|
get_default_llms() if not persona else get_llms_for_persona(persona)
|
||||||
|
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
83
backend/ee/danswer/server/query_and_chat/utils.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from danswer.auth.users import is_user_admin
|
||||||
|
from danswer.db.llm import fetch_existing_doc_sets
|
||||||
|
from danswer.db.llm import fetch_existing_tools
|
||||||
|
from danswer.db.models import Persona
|
||||||
|
from danswer.db.models import Prompt
|
||||||
|
from danswer.db.models import Tool
|
||||||
|
from danswer.db.models import User
|
||||||
|
from danswer.db.persona import get_prompts_by_ids
|
||||||
|
from danswer.one_shot_answer.models import PersonaConfig
|
||||||
|
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||||
|
|
||||||
|
|
||||||
|
def create_temporary_persona(
    persona_config: PersonaConfig, db_session: Session, user: User | None = None
) -> Persona:
    """Create a temporary Persona object from the provided configuration.

    The persona is only constructed and populated here; this function does
    not itself call ``db_session.add`` or ``commit``.

    Args:
        persona_config: Request-supplied persona definition.
        db_session: Session used to look up referenced prompts, tools and
            document sets.
        user: The requesting user; passed to ``is_user_admin`` for the
            authorization check.

    Returns:
        A ``Persona`` populated with the configured prompts, tools and
        document sets. IDs that match no stored row are silently skipped by
        the underlying lookups.

    Raises:
        HTTPException: 403 when ``is_user_admin(user)`` is false.
    """
    if not is_user_admin(user):
        raise HTTPException(
            status_code=403,
            detail="User is not authorized to create a persona in one shot queries",
        )

    persona = Persona(
        name=persona_config.name,
        description=persona_config.description,
        num_chunks=persona_config.num_chunks,
        llm_relevance_filter=persona_config.llm_relevance_filter,
        llm_filter_extraction=persona_config.llm_filter_extraction,
        recency_bias=persona_config.recency_bias,
        llm_model_provider_override=persona_config.llm_model_provider_override,
        llm_model_version_override=persona_config.llm_model_version_override,
    )

    # Inline prompt definitions take precedence over references to stored
    # prompts; when neither is supplied the persona's prompts are left unset.
    if persona_config.prompts:
        persona.prompts = [
            Prompt(
                name=p.name,
                description=p.description,
                system_prompt=p.system_prompt,
                task_prompt=p.task_prompt,
                include_citations=p.include_citations,
                datetime_aware=p.datetime_aware,
            )
            for p in persona_config.prompts
        ]
    elif persona_config.prompt_ids:
        persona.prompts = get_prompts_by_ids(
            db_session=db_session, prompt_ids=persona_config.prompt_ids
        )

    persona_tools: list[Tool] = []

    # NOTE(review): the cast only satisfies the type checker; confirm that
    # build_custom_tools_from_openapi_schema really yields objects that are
    # valid members of Persona.tools.
    for schema in persona_config.custom_tools_openapi:
        persona_tools.extend(
            cast(list[Tool], build_custom_tools_from_openapi_schema(schema))
        )

    # `tools` and `tool_ids` both reference stored tools by ID. Merge them
    # (order-preserving, de-duplicated) so a tool named in both is attached
    # once and a single query is issued.
    requested_tool_ids = list(
        dict.fromkeys(
            [tool.id for tool in persona_config.tools] + persona_config.tool_ids
        )
    )
    if requested_tool_ids:
        persona_tools.extend(
            fetch_existing_tools(db_session=db_session, tool_ids=requested_tool_ids)
        )
    persona.tools = persona_tools

    # Only hit the database when document sets were actually requested.
    if persona_config.document_set_ids:
        persona.document_sets = fetch_existing_doc_sets(
            db_session=db_session, doc_ids=persona_config.document_set_ids
        )
    else:
        persona.document_sets = []

    return persona
|
@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel):
|
|||||||
name: str | None
|
name: str | None
|
||||||
first_user_message: str
|
first_user_message: str
|
||||||
first_ai_message: str
|
first_ai_message: str
|
||||||
persona_name: str
|
persona_name: str | None
|
||||||
time_created: datetime
|
time_created: datetime
|
||||||
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
feedback_type: QAFeedbackType | Literal["mixed"] | None
|
||||||
|
|
||||||
@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel):
|
|||||||
user_email: str
|
user_email: str
|
||||||
name: str | None
|
name: str | None
|
||||||
messages: list[MessageSnapshot]
|
messages: list[MessageSnapshot]
|
||||||
persona_name: str
|
persona_name: str | None
|
||||||
time_created: datetime
|
time_created: datetime
|
||||||
|
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
|||||||
retrieved_documents: list[AbridgedSearchDoc]
|
retrieved_documents: list[AbridgedSearchDoc]
|
||||||
feedback_type: QAFeedbackType | None
|
feedback_type: QAFeedbackType | None
|
||||||
feedback_text: str | None
|
feedback_text: str | None
|
||||||
persona_name: str
|
persona_name: str | None
|
||||||
user_email: str
|
user_email: str
|
||||||
time_created: datetime
|
time_created: datetime
|
||||||
|
|
||||||
@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
|
|||||||
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
for ind, (user_message, ai_message) in enumerate(message_pairs)
|
||||||
]
|
]
|
||||||
|
|
||||||
def to_json(self) -> dict[str, str]:
|
def to_json(self) -> dict[str, str | None]:
|
||||||
return {
|
return {
|
||||||
"chat_session_id": str(self.chat_session_id),
|
"chat_session_id": str(self.chat_session_id),
|
||||||
"message_pair_num": str(self.message_pair_num),
|
"message_pair_num": str(self.message_pair_num),
|
||||||
@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal(
|
|||||||
name=chat_session.description,
|
name=chat_session.description,
|
||||||
first_user_message=first_user_message,
|
first_user_message=first_user_message,
|
||||||
first_ai_message=first_ai_message,
|
first_ai_message=first_ai_message,
|
||||||
persona_name=chat_session.persona.name,
|
persona_name=chat_session.persona.name
|
||||||
|
if chat_session.persona
|
||||||
|
else None,
|
||||||
time_created=chat_session.time_created,
|
time_created=chat_session.time_created,
|
||||||
feedback_type=feedback_type,
|
feedback_type=feedback_type,
|
||||||
)
|
)
|
||||||
@ -300,7 +302,7 @@ def snapshot_from_chat_session(
|
|||||||
for message in messages
|
for message in messages
|
||||||
if message.message_type != MessageType.SYSTEM
|
if message.message_type != MessageType.SYSTEM
|
||||||
],
|
],
|
||||||
persona_name=chat_session.persona.name,
|
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||||
time_created=chat_session.time_created,
|
time_created=chat_session.time_created,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user