diff --git a/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py b/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py new file mode 100644 index 0000000000..58fcf482c8 --- /dev/null +++ b/backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py @@ -0,0 +1,31 @@ +"""add nullable to persona id in Chat Session + +Revision ID: c99d76fcd298 +Revises: 5c7fdadae813 +Create Date: 2024-07-09 19:27:01.579697 + +""" + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "c99d76fcd298" +down_revision = "5c7fdadae813" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.alter_column( + "chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True + ) + + +def downgrade() -> None: + op.alter_column( + "chat_session", + "persona_id", + existing_type=sa.INTEGER(), + nullable=False, + ) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 536566c4ae..223f3b5ce4 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -675,9 +675,11 @@ def stream_chat_message_objects( db_session=db_session, selected_search_docs=selected_db_search_docs, # Deduping happens at the last step to avoid harming quality by dropping content early on - dedupe_docs=retrieval_options.dedupe_docs - if retrieval_options - else False, + dedupe_docs=( + retrieval_options.dedupe_docs + if retrieval_options + else False + ), ) yield qa_docs_response elif packet.id == SECTION_RELEVANCE_LIST_ID: @@ -786,16 +788,18 @@ def stream_chat_message_objects( if message_specific_citations else None, error=None, - tool_calls=[ - ToolCall( - tool_id=tool_name_to_tool_id[tool_result.tool_name], - tool_name=tool_result.tool_name, - tool_arguments=tool_result.tool_args, - tool_result=tool_result.tool_result, - ) - ] - if tool_result - else [], + tool_calls=( + [ + ToolCall( + tool_id=tool_name_to_tool_id[tool_result.tool_name], + tool_name=tool_result.tool_name, + tool_arguments=tool_result.tool_args, + tool_result=tool_result.tool_result, + ) + ] + if tool_result + else [] + ), ) logger.debug("Committing messages") diff --git a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py index 776c46dea3..f1c9bd077c 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_regular_answer.py @@ -5,6 +5,7 @@ from typing import cast from typing import Optional from typing import TypeVar +from fastapi import HTTPException from retry import retry from slack_sdk import WebClient from slack_sdk.models.blocks import DividerBlock @@ -153,15 +154,23 @@ def handle_regular_answer( with Session(get_sqlalchemy_engine()) as db_session: if len(new_message_request.messages) > 1: - persona = cast( - Persona, - fetch_persona_by_id( - db_session, - new_message_request.persona_id, - user=None, - get_editable=False, - ), - ) + if new_message_request.persona_config: + raise HTTPException( + status_code=403, + detail="Slack bot does not support persona config", + ) + + elif new_message_request.persona_id: + persona = cast( + Persona, + fetch_persona_by_id( + db_session, + new_message_request.persona_id, + user=None, + get_editable=False, + ), + ) + llm, _ = get_llms_for_persona(persona) # In cases of threads, split the available tokens between docs and thread context diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 8485bb4f0a..8599714ce8 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -226,7 +226,7 @@ def create_chat_session( db_session: Session, description: str, user_id: UUID | None, - persona_id: int, + persona_id: int | None, # Can be none if temporary persona is used llm_override: LLMOverride | None = None, prompt_override: PromptOverride | None = None, one_shot: bool = False, diff --git a/backend/danswer/db/llm.py b/backend/danswer/db/llm.py index a68beadc08..36d05948be 100644 --- a/backend/danswer/db/llm.py +++ b/backend/danswer/db/llm.py @@ -4,9 +4,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel +from danswer.db.models import DocumentSet from danswer.db.models import LLMProvider as LLMProviderModel from danswer.db.models import LLMProvider__UserGroup from danswer.db.models import SearchSettings +from danswer.db.models import Tool as ToolModel from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.server.manage.embedding.models import CloudEmbeddingProvider @@ -103,6 +105,20 @@ def fetch_existing_embedding_providers( return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all()) +def fetch_existing_doc_sets( + db_session: Session, doc_ids: list[int] +) -> list[DocumentSet]: + return list( + db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all() + ) + + +def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]: + return list( + db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all() + ) + + def fetch_existing_llm_providers( db_session: Session, user: User | None = None, diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 16a6459f38..d18d99dbc2 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -866,7 +866,9 @@ class ChatSession(Base): id: Mapped[int] = mapped_column(primary_key=True) user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) - persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id")) + persona_id: Mapped[int | None] = mapped_column( + ForeignKey("persona.id"), nullable=True + ) description: Mapped[str] = mapped_column(Text) # One-shot direct answering, currently the two types of chats are not mixed one_shot: Mapped[bool] = mapped_column(Boolean, default=False) @@ -900,7 +902,6 @@ class ChatSession(Base): prompt_override: Mapped[PromptOverride | None] = mapped_column( PydanticType(PromptOverride), nullable=True ) - time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), @@ -909,7 +910,6 @@ class ChatSession(Base): time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) - user: Mapped[User] = relationship("User", back_populates="chat_sessions") folder: Mapped["ChatFolder"] = relationship( "ChatFolder", back_populates="chat_sessions" diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index bbf45a1d9a..2064021b3a 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -563,13 +563,15 @@ def validate_persona_tools(tools: list[Tool]) -> None: ) -def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]: +def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]: """Unsafe, can fetch prompts from all users""" if not prompt_ids: return [] - prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all() + prompts = db_session.scalars( + select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False)) + ).all() - return prompts + return list(prompts) def get_prompt_by_id( diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 3f83ad1955..f051da82f1 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.chat import update_search_docs_table_with_relevance from danswer.db.engine import get_session_context_manager +from danswer.db.models import Persona from danswer.db.models import User from danswer.db.persona import get_prompt_by_id from danswer.llm.answering.answer import Answer @@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse from danswer.tools.tool_runner import ToolCallKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time - +from ee.danswer.server.query_and_chat.utils import create_temporary_persona logger = setup_logger() @@ -118,7 +119,17 @@ def stream_answer_objects( one_shot=True, danswerbot_flow=danswerbot_flow, ) - llm, fast_llm = get_llms_for_persona(persona=chat_session.persona) + + temporary_persona: Persona | None = None + if query_req.persona_config is not None: + new_persona = create_temporary_persona( + db_session=db_session, persona_config=query_req.persona_config, user=user + ) + temporary_persona = new_persona + + persona = temporary_persona if temporary_persona else chat_session.persona + + llm, fast_llm = get_llms_for_persona(persona=persona) llm_tokenizer = get_tokenizer( model_name=llm.config.model_name, @@ -153,11 +164,11 @@ def stream_answer_objects( prompt_id=query_req.prompt_id, user=None, db_session=db_session ) if prompt is None: - if not chat_session.persona.prompts: + if not persona.prompts: raise RuntimeError( "Persona does not have any prompts - this should never happen" ) - prompt = chat_session.persona.prompts[0] + prompt = persona.prompts[0] # Create the first User query message new_user_message = create_new_chat_message( @@ -174,9 +185,7 @@ def stream_answer_objects( prompt_config = PromptConfig.from_model(prompt) document_pruning_config = DocumentPruningConfig( max_chunks=int( - chat_session.persona.num_chunks - if chat_session.persona.num_chunks is not None - else default_num_chunks + persona.num_chunks if persona.num_chunks is not None else default_num_chunks ), max_tokens=max_document_tokens, ) @@ -187,16 +196,16 @@ def stream_answer_objects( evaluation_type=LLMEvaluationType.SKIP if DISABLE_LLM_DOC_RELEVANCE else query_req.evaluation_type, - persona=chat_session.persona, + persona=persona, retrieval_options=query_req.retrieval_options, prompt_config=prompt_config, llm=llm, fast_llm=fast_llm, pruning_config=document_pruning_config, + bypass_acl=bypass_acl, chunks_above=query_req.chunks_above, chunks_below=query_req.chunks_below, full_doc=query_req.full_doc, - bypass_acl=bypass_acl, ) answer_config = AnswerStyleConfig( @@ -209,13 +218,15 @@ def stream_answer_objects( question=query_msg.message, answer_style_config=answer_config, prompt_config=PromptConfig.from_model(prompt), - llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)), + llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)), single_message_history=history_str, - tools=[search_tool], - force_use_tool=ForceUseTool( - force_use=True, - tool_name=search_tool.name, - args={"query": rephrased_query}, + tools=[search_tool] if search_tool else [], + force_use_tool=( + ForceUseTool( + tool_name=search_tool.name, + args={"query": rephrased_query}, + force_use=True, + ) ), # for now, don't use tool calling for this flow, as we haven't # tested quotes with tool calling too much yet @@ -223,9 +234,7 @@ def stream_answer_objects( return_contexts=query_req.return_contexts, skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation, ) - # won't be any ImageGenerationDisplay responses since that tool is never passed in - for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): # for one-shot flow, don't currently do anything with these if isinstance(packet, ToolResponse): @@ -261,6 +270,7 @@ def stream_answer_objects( applied_time_cutoff=search_response_summary.final_filters.time_cutoff, recency_bias_multiplier=search_response_summary.recency_bias_multiplier, ) + yield initial_response elif packet.id == SEARCH_DOC_CONTENT_ID: @@ -287,6 +297,7 @@ def stream_answer_objects( relevance_summary=evaluation_response, ) yield evaluation_response + else: yield packet diff --git a/backend/danswer/one_shot_answer/models.py b/backend/danswer/one_shot_answer/models.py index fceb78de7a..735fc12bbb 100644 --- a/backend/danswer/one_shot_answer/models.py +++ b/backend/danswer/one_shot_answer/models.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel from pydantic import Field from pydantic import model_validator @@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes from danswer.chat.models import QADocsResponse from danswer.configs.constants import MessageType from danswer.search.enums import LLMEvaluationType +from danswer.search.enums import RecencyBiasSetting +from danswer.search.enums import SearchType from danswer.search.models import ChunkContext from danswer.search.models import RerankingDetails from danswer.search.models import RetrievalDetails @@ -23,10 +27,49 @@ class ThreadMessage(BaseModel): role: MessageType = MessageType.USER +class PromptConfig(BaseModel): + name: str + description: str = "" + system_prompt: str + task_prompt: str = "" + include_citations: bool = True + datetime_aware: bool = True + + +class DocumentSetConfig(BaseModel): + id: int + + +class ToolConfig(BaseModel): + id: int + + +class PersonaConfig(BaseModel): + name: str + description: str + search_type: SearchType = SearchType.SEMANTIC + num_chunks: float | None = None + llm_relevance_filter: bool = False + llm_filter_extraction: bool = False + recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO + llm_model_provider_override: str | None = None + llm_model_version_override: str | None = None + + prompts: list[PromptConfig] = Field(default_factory=list) + prompt_ids: list[int] = Field(default_factory=list) + + document_set_ids: list[int] = Field(default_factory=list) + tools: list[ToolConfig] = Field(default_factory=list) + tool_ids: list[int] = Field(default_factory=list) + custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list) + + class DirectQARequest(ChunkContext): + persona_config: PersonaConfig | None = None + persona_id: int | None = None + messages: list[ThreadMessage] - prompt_id: int | None - persona_id: int + prompt_id: int | None = None multilingual_query_expansion: list[str] | None = None retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails) rerank_settings: RerankingDetails | None = None @@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext): # If True, skips generative an AI response to the search query skip_gen_ai_answer_generation: bool = False + @model_validator(mode="after") + def check_persona_fields(self) -> "DirectQARequest": + if (self.persona_config is None) == (self.persona_id is None): + raise ValueError("Exactly one of persona_config or persona_id must be set") + return self + @model_validator(mode="after") def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest": if self.chain_of_thought and self.prompt_id is not None: diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 20ae7124fa..c7f5983417 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -164,7 +164,7 @@ def get_chat_session( chat_session_id=session_id, description=chat_session.description, persona_id=chat_session.persona_id, - persona_name=chat_session.persona.name, + persona_name=chat_session.persona.name if chat_session.persona else None, current_alternate_model=chat_session.current_alternate_model, messages=[ translate_db_message_to_chat_message_detail( diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 55d1094ea8..c9109b141c 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel): class ChatSessionDetails(BaseModel): id: int name: str - persona_id: int + persona_id: int | None = None time_created: str shared_status: ChatSessionSharedStatus folder_id: int | None = None @@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel): class ChatSessionDetailResponse(BaseModel): chat_session_id: int description: str - persona_id: int - persona_name: str + persona_id: int | None = None + persona_name: str | None messages: list[ChatMessageDetail] time_created: datetime shared_status: ChatSessionSharedStatus diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py index 2213bfca61..59e61ba12d 100644 --- a/backend/ee/danswer/server/query_and_chat/query_backend.py +++ b/backend/ee/danswer/server/query_and_chat/query_backend.py @@ -32,6 +32,7 @@ from ee.danswer.danswerbot.slack.handlers.handle_standard_answers import ( from ee.danswer.server.query_and_chat.models import DocumentSearchRequest from ee.danswer.server.query_and_chat.models import StandardAnswerRequest from ee.danswer.server.query_and_chat.models import StandardAnswerResponse +from ee.danswer.server.query_and_chat.utils import create_temporary_persona logger = setup_logger() @@ -133,12 +134,23 @@ def get_answer_with_quote( query = query_request.messages[0].message logger.notice(f"Received query for one shot answer API with quotes: {query}") - persona = get_persona_by_id( - persona_id=query_request.persona_id, - user=user, - db_session=db_session, - is_for_edit=False, - ) + if query_request.persona_config is not None: + new_persona = create_temporary_persona( + db_session=db_session, + persona_config=query_request.persona_config, + user=user, + ) + persona = new_persona + + elif query_request.persona_id is not None: + persona = get_persona_by_id( + persona_id=query_request.persona_id, + user=user, + db_session=db_session, + is_for_edit=False, + ) + else: + raise KeyError("Must provide persona ID or Persona Config") llm = get_main_llm_from_tuple( get_default_llms() if not persona else get_llms_for_persona(persona) diff --git a/backend/ee/danswer/server/query_and_chat/utils.py b/backend/ee/danswer/server/query_and_chat/utils.py new file mode 100644 index 0000000000..beb970fd1b --- /dev/null +++ b/backend/ee/danswer/server/query_and_chat/utils.py @@ -0,0 +1,83 @@ +from typing import cast + +from fastapi import HTTPException +from sqlalchemy.orm import Session + +from danswer.auth.users import is_user_admin +from danswer.db.llm import fetch_existing_doc_sets +from danswer.db.llm import fetch_existing_tools +from danswer.db.models import Persona +from danswer.db.models import Prompt +from danswer.db.models import Tool +from danswer.db.models import User +from danswer.db.persona import get_prompts_by_ids +from danswer.one_shot_answer.models import PersonaConfig +from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema + + +def create_temporary_persona( + persona_config: PersonaConfig, db_session: Session, user: User | None = None +) -> Persona: + if not is_user_admin(user): + raise HTTPException( + status_code=403, + detail="User is not authorized to create a persona in one shot queries", + ) + + """Create a temporary Persona object from the provided configuration.""" + persona = Persona( + name=persona_config.name, + description=persona_config.description, + num_chunks=persona_config.num_chunks, + llm_relevance_filter=persona_config.llm_relevance_filter, + llm_filter_extraction=persona_config.llm_filter_extraction, + recency_bias=persona_config.recency_bias, + llm_model_provider_override=persona_config.llm_model_provider_override, + llm_model_version_override=persona_config.llm_model_version_override, + ) + + if persona_config.prompts: + persona.prompts = [ + Prompt( + name=p.name, + description=p.description, + system_prompt=p.system_prompt, + task_prompt=p.task_prompt, + include_citations=p.include_citations, + datetime_aware=p.datetime_aware, + ) + for p in persona_config.prompts + ] + elif persona_config.prompt_ids: + persona.prompts = get_prompts_by_ids( + db_session=db_session, prompt_ids=persona_config.prompt_ids + ) + + persona.tools = [] + if persona_config.custom_tools_openapi: + for schema in persona_config.custom_tools_openapi: + tools = cast( + list[Tool], + build_custom_tools_from_openapi_schema(schema), + ) + persona.tools.extend(tools) + + if persona_config.tools: + tool_ids = [tool.id for tool in persona_config.tools] + persona.tools.extend( + fetch_existing_tools(db_session=db_session, tool_ids=tool_ids) + ) + + if persona_config.tool_ids: + persona.tools.extend( + fetch_existing_tools( + db_session=db_session, tool_ids=persona_config.tool_ids + ) + ) + + fetched_docs = fetch_existing_doc_sets( + db_session=db_session, doc_ids=persona_config.document_set_ids + ) + persona.document_sets = fetched_docs + + return persona diff --git a/backend/ee/danswer/server/query_history/api.py b/backend/ee/danswer/server/query_history/api.py index ed532a8560..dbdf3d8bc4 100644 --- a/backend/ee/danswer/server/query_history/api.py +++ b/backend/ee/danswer/server/query_history/api.py @@ -87,7 +87,7 @@ class ChatSessionMinimal(BaseModel): name: str | None first_user_message: str first_ai_message: str - persona_name: str + persona_name: str | None time_created: datetime feedback_type: QAFeedbackType | Literal["mixed"] | None @@ -97,7 +97,7 @@ class ChatSessionSnapshot(BaseModel): user_email: str name: str | None messages: list[MessageSnapshot] - persona_name: str + persona_name: str | None time_created: datetime @@ -111,7 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel): retrieved_documents: list[AbridgedSearchDoc] feedback_type: QAFeedbackType | None feedback_text: str | None - persona_name: str + persona_name: str | None user_email: str time_created: datetime @@ -145,7 +145,7 @@ class QuestionAnswerPairSnapshot(BaseModel): for ind, (user_message, ai_message) in enumerate(message_pairs) ] - def to_json(self) -> dict[str, str]: + def to_json(self) -> dict[str, str | None]: return { "chat_session_id": str(self.chat_session_id), "message_pair_num": str(self.message_pair_num), @@ -235,7 +235,9 @@ def fetch_and_process_chat_session_history_minimal( name=chat_session.description, first_user_message=first_user_message, first_ai_message=first_ai_message, - persona_name=chat_session.persona.name, + persona_name=chat_session.persona.name + if chat_session.persona + else None, time_created=chat_session.time_created, feedback_type=feedback_type, ) @@ -300,7 +302,7 @@ def snapshot_from_chat_session( for message in messages if message.message_type != MessageType.SYSTEM ], - persona_name=chat_session.persona.name, + persona_name=chat_session.persona.name if chat_session.persona else None, time_created=chat_session.time_created, )