Add image upload capabilities

Weves 2024-04-29 01:18:47 -07:00 committed by Chris Weaver
parent 350e548b2d
commit 5b93e786ad
33 changed files with 992 additions and 363 deletions

View File

@@ -0,0 +1,27 @@
"""Add files to ChatMessage

Revision ID: ef7da92f7213
Revises: 401c1ac29467
Create Date: 2024-04-28 16:59:33.199153

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "ef7da92f7213"
down_revision = "401c1ac29467"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "files")
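
For reference, each entry stored in the new `files` column follows the `FileDescriptor` shape introduced later in this commit. A minimal sketch of a stored payload (the UUID below is made up for illustration):

# Hypothetical example of the JSONB value held in chat_message.files:
# a list of FileDescriptor-shaped dicts. "id" is the UUID assigned at
# upload time and "image" corresponds to ChatFileType.IMAGE.
example_files_payload = [
    {"id": "6f1c1a0e-1b2d-4c3e-9f4a-5b6c7d8e9f0a", "type": "image"},
]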

View File

@@ -30,6 +30,8 @@ from danswer.db.engine import get_session_context_manager
 from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
+from danswer.file_store.models import ChatFileType
+from danswer.file_store.utils import load_all_chat_files
 from danswer.llm.answering.answer import Answer
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import CitationConfig
@@ -174,6 +176,10 @@ def stream_chat_message_objects(
             message=message_text,
             token_count=len(llm_tokenizer_encode_func(message_text)),
             message_type=MessageType.USER,
+            files=[
+                {"id": str(file_id), "type": ChatFileType.IMAGE}
+                for file_id in new_msg_req.file_ids
+            ],
             db_session=db_session,
             commit=False,
         )
@@ -202,9 +208,20 @@ def stream_chat_message_objects(
                 "when the last message is not a user message."
             )

+        # load all files needed for this chat chain in memory
+        files = load_all_chat_files(history_msgs, new_msg_req.file_ids, db_session)
+        latest_query_files = [
+            file for file in files if file.file_id in new_msg_req.file_ids
+        ]
+
         run_search = False
         # Retrieval options are only None if reference_doc_ids are provided
-        if retrieval_options is not None and persona.num_chunks != 0:
+        # Also don't perform search if the user uploaded at least one file - just use the files
+        if (
+            retrieval_options is not None
+            and persona.num_chunks != 0
+            and not new_msg_req.file_ids
+        ):
             if retrieval_options.run_search == OptionalSearchSetting.ALWAYS:
                 run_search = True
             elif retrieval_options.run_search == OptionalSearchSetting.NEVER:
@@ -360,6 +377,7 @@ def stream_chat_message_objects(
         answer = Answer(
             question=final_msg.message,
             docs=llm_docs,
+            latest_query_files=latest_query_files,
             answer_style_config=AnswerStyleConfig(
                 citation_config=CitationConfig(
                     all_docs_useful=reference_db_search_docs is not None
@@ -380,7 +398,7 @@ def stream_chat_message_objects(
             ),
             doc_relevance_list=llm_relevance_list,
             message_history=[
-                PreviousMessage.from_chat_message(msg) for msg in history_msgs
+                PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
             ],
         )
         # generator will not include quotes, so we can cast
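
To make the new gating rule above concrete, here is a standalone sketch of the decision with stand-in values (the function and arguments are illustrative, not part of the Danswer API):

def should_run_retrieval(
    retrieval_options_present: bool,
    persona_num_chunks: int | None,
    uploaded_file_ids: list[str],
) -> bool:
    # search is only considered when retrieval options exist, the persona
    # allows chunks, and no files were attached - uploaded images are
    # answered directly by the LLM rather than via document search
    return (
        retrieval_options_present
        and persona_num_chunks != 0
        and not uploaded_file_ids
    )


assert should_run_retrieval(True, 10, []) is True
assert should_run_retrieval(True, 10, ["some-file-id"]) is False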

View File

@@ -23,7 +23,7 @@ from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.db.file_store import get_default_file_store
+from danswer.file_store.file_store import get_default_file_store
 from danswer.utils.logger import setup_logger

 logger = setup_logger()

View File

@@ -17,7 +17,7 @@ from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.db.engine import get_sqlalchemy_engine
-from danswer.db.file_store import get_default_file_store
+from danswer.file_store.file_store import get_default_file_store
 from danswer.utils.logger import setup_logger

 logger = setup_logger()

View File

@@ -30,6 +30,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.db.models import StarterMessage
 from danswer.db.models import User
 from danswer.db.models import User__UserGroup
+from danswer.file_store.models import FileDescriptor
 from danswer.llm.override_models import LLMOverride
 from danswer.llm.override_models import PromptOverride
 from danswer.search.enums import RecencyBiasSetting
@@ -256,6 +257,7 @@ def create_new_chat_message(
     token_count: int,
     message_type: MessageType,
     db_session: Session,
+    files: list[FileDescriptor] | None = None,
     rephrased_query: str | None = None,
     error: str | None = None,
     reference_docs: list[DBSearchDoc] | None = None,
@@ -273,6 +275,7 @@ def create_new_chat_message(
         token_count=token_count,
         message_type=message_type,
         citations=citations,
+        files=files,
         error=error,
     )
@@ -819,6 +822,7 @@ def translate_db_message_to_chat_message_detail(
         message_type=chat_message.message_type,
         time_sent=chat_message.time_sent,
         citations=chat_message.citations,
+        files=chat_message.files or [],
     )
     return chat_msg_detail

View File

@@ -42,6 +42,7 @@ from danswer.db.enums import IndexModelStatus
 from danswer.db.enums import TaskStatus
 from danswer.db.pydantic_type import PydanticType
 from danswer.dynamic_configs.interface import JSON_ro
+from danswer.file_store.models import FileDescriptor
 from danswer.llm.override_models import LLMOverride
 from danswer.llm.override_models import PromptOverride
 from danswer.search.enums import RecencyBiasSetting
@@ -629,6 +630,11 @@ class ChatMessage(Base):
     )
     # Maps the citation numbers to a SearchDoc id
     citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True)
+    # files associated with this message (e.g. images uploaded by the user
+    # that the user is asking a question about)
+    files: Mapped[list[FileDescriptor] | None] = mapped_column(
+        postgresql.JSONB(), nullable=True
+    )
     # Only applies for LLM
     error: Mapped[str | None] = mapped_column(Text, nullable=True)
     time_sent: Mapped[datetime.datetime] = mapped_column(

View File

@@ -0,0 +1,33 @@
import base64
from enum import Enum
from typing import TypedDict
from uuid import UUID

from pydantic import BaseModel


class ChatFileType(str, Enum):
    IMAGE = "image"


class FileDescriptor(TypedDict):
    """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
    in Postgres"""

    id: str
    type: ChatFileType


class InMemoryChatFile(BaseModel):
    file_id: UUID
    content: bytes
    file_type: ChatFileType = ChatFileType.IMAGE

    def to_base64(self) -> str:
        return base64.b64encode(self.content).decode()

    def to_file_descriptor(self) -> FileDescriptor:
        return {
            "id": str(self.file_id),
            "type": self.file_type,
        }
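
A quick usage sketch of the models above (assuming this module is importable; the content bytes are made up):

import uuid

from danswer.file_store.models import ChatFileType, InMemoryChatFile

file = InMemoryChatFile(file_id=uuid.uuid4(), content=b"\x89PNG\r\n...")
descriptor = file.to_file_descriptor()
assert descriptor["type"] == ChatFileType.IMAGE
b64 = file.to_base64()  # base64 string, usable in a data: URL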

View File

@@ -0,0 +1,40 @@
from typing import cast
from uuid import UUID

from sqlalchemy.orm import Session

from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def build_chat_file_name(file_id: UUID | str) -> str:
    return f"chat__{file_id}"


def load_chat_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
    file_io = get_default_file_store(db_session).read_file(
        build_chat_file_name(file_id), mode="b"
    )
    return InMemoryChatFile(file_id=file_id, content=file_io.read())


def load_all_chat_files(
    chat_messages: list[ChatMessage], new_file_ids: list[UUID], db_session: Session
) -> list[InMemoryChatFile]:
    file_ids_for_history = []
    for chat_message in chat_messages:
        if chat_message.files:
            file_ids_for_history.extend([file["id"] for file in chat_message.files])

    files = cast(
        list[InMemoryChatFile],
        run_functions_tuples_in_parallel(
            [
                (load_chat_file, (file_id, db_session))
                for file_id in new_file_ids + file_ids_for_history
            ]
        ),
    )
    return files
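
One subtlety: `new_file_ids` holds `UUID` objects while ids recovered from message history are plain strings, which is why `build_chat_file_name` accepts both. A quick check of the resulting store keys (assuming the helper above is importable):

import uuid

from danswer.file_store.utils import build_chat_file_name

file_id = uuid.uuid4()
# UUID and string forms map to the same key in the shared file store
assert build_chat_file_name(file_id) == build_chat_file_name(str(file_id))
assert build_chat_file_name(file_id) == f"chat__{file_id}"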

View File

@@ -9,7 +9,7 @@ from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import LlmDoc
 from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
-from danswer.configs.chat_configs import QA_TIMEOUT
+from danswer.file_store.utils import InMemoryChatFile
 from danswer.llm.answering.doc_pruning import prune_documents
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import PreviousMessage
@@ -58,7 +58,9 @@ class Answer:
         doc_relevance_list: list[bool] | None = None,
         message_history: list[PreviousMessage] | None = None,
         single_message_history: str | None = None,
-        timeout: int = QA_TIMEOUT,
+        # newly passed in files to include as part of this question
+        latest_query_files: list[InMemoryChatFile] | None = None,
+        files: list[InMemoryChatFile] | None = None,
     ) -> None:
         if single_message_history and message_history:
             raise ValueError(
@@ -67,6 +69,10 @@ class Answer:
         self.question = question
         self.docs = docs
+
+        self.latest_query_files = latest_query_files or []
+        self.file_id_to_file = {file.file_id: file for file in (files or [])}
+
         self.doc_relevance_list = doc_relevance_list
         self.message_history = message_history or []
         # used for QA flow where we only want to send a single message
@@ -112,11 +118,15 @@ class Answer:
                 llm_config=self.llm.config,
                 prompt_config=self.prompt_config,
                 context_docs=self.pruned_docs,
+                latest_query_files=self.latest_query_files,
                 all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
                 llm_tokenizer_encode_func=self.llm_tokenizer.encode,
                 history_message=self.single_message_history or "",
             )
         elif self.answer_style_config.quotes_config:
+            # NOTE: quotes prompt doesn't currently support files
+            # this is okay for now, since the search UI (which uses this)
+            # doesn't support image upload
             self._final_prompt = build_quotes_prompt(
                 question=self.question,
                 context_docs=self.pruned_docs,

View File

@@ -9,6 +9,7 @@ from pydantic import root_validator
 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.configs.constants import MessageType
+from danswer.file_store.models import InMemoryChatFile
 from danswer.llm.override_models import PromptOverride

 if TYPE_CHECKING:
@@ -25,13 +26,24 @@ class PreviousMessage(BaseModel):
     message: str
     token_count: int
     message_type: MessageType
+    files: list[InMemoryChatFile]

     @classmethod
-    def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage":
+    def from_chat_message(
+        cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
+    ) -> "PreviousMessage":
+        message_file_ids = (
+            [file["id"] for file in chat_message.files] if chat_message.files else []
+        )
         return cls(
             message=chat_message.message,
             token_count=chat_message.token_count,
             message_type=chat_message.message_type,
+            files=[
+                file
+                for file in available_files
+                if str(file.file_id) in message_file_ids
+            ],
         )
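
The matching above bridges a type mismatch: descriptors in `ChatMessage.files` store ids as strings, while loaded `InMemoryChatFile` objects carry `UUID`s, so the comparison goes through `str()`. A standalone sketch of the rule with stand-in data:

import uuid

loaded_file_ids = [uuid.uuid4(), uuid.uuid4()]  # files already in memory
message_file_ids = [str(loaded_file_ids[0])]  # this message references one

matched = [fid for fid in loaded_file_ids if str(fid) in message_file_ids]
assert matched == [loaded_file_ids[0]]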

View File

@@ -11,10 +11,12 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_default_prompt
 from danswer.db.models import Persona
+from danswer.file_store.utils import InMemoryChatFile
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.factory import get_llm_for_persona
 from danswer.llm.interfaces import LLMConfig
+from danswer.llm.utils import build_content_with_imgs
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import get_max_input_tokens
@@ -59,7 +61,7 @@ def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
     return last_ind


-def drop_messages_history_overflow(
+def _drop_messages_history_overflow(
     system_msg: BaseMessage | None,
     system_token_count: int,
     history_msgs: list[BaseMessage],
@@ -171,7 +173,7 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
 @lru_cache()
-def build_system_message(
+def _build_system_message(
     prompt_config: PromptConfig,
     context_exists: bool,
     llm_tokenizer_encode_func: Callable,
@@ -201,10 +203,11 @@ def build_system_message(
     return system_msg, token_count


-def build_user_message(
+def _build_user_message(
     question: str,
     prompt_config: PromptConfig,
     context_docs: list[LlmDoc] | list[InferenceChunk],
+    files: list[InMemoryChatFile],
     all_doc_useful: bool,
     history_message: str,
 ) -> tuple[HumanMessage, int]:
@@ -222,7 +225,11 @@ def build_user_message(
         )
         user_prompt = user_prompt.strip()
         token_count = len(llm_tokenizer_encode_func(user_prompt))
-        user_msg = HumanMessage(content=user_prompt)
+        user_msg = HumanMessage(
+            content=build_content_with_imgs(user_prompt, files)
+            if files
+            else user_prompt
+        )
         return user_msg, token_count

     context_docs_str = build_complete_context_str(context_docs)
@@ -240,7 +247,9 @@ def build_user_message(
     user_prompt = user_prompt.strip()
     token_count = len(llm_tokenizer_encode_func(user_prompt))
-    user_msg = HumanMessage(content=user_prompt)
+    user_msg = HumanMessage(
+        content=build_content_with_imgs(user_prompt, files) if files else user_prompt
+    )

     return user_msg, token_count
@@ -251,13 +260,14 @@ def build_citations_prompt(
     prompt_config: PromptConfig,
     llm_config: LLMConfig,
     context_docs: list[LlmDoc] | list[InferenceChunk],
+    latest_query_files: list[InMemoryChatFile],
     all_doc_useful: bool,
     history_message: str,
     llm_tokenizer_encode_func: Callable,
 ) -> list[BaseMessage]:
     context_exists = len(context_docs) > 0

-    system_message_or_none, system_tokens = build_system_message(
+    system_message_or_none, system_tokens = _build_system_message(
         prompt_config=prompt_config,
         context_exists=context_exists,
         llm_tokenizer_encode_func=llm_tokenizer_encode_func,
@@ -269,15 +279,16 @@ def build_citations_prompt(
     # Be sure the context_docs passed to build_chat_user_message
     # Is the same as passed in later for extracting citations
-    user_message, user_tokens = build_user_message(
+    user_message, user_tokens = _build_user_message(
         question=question,
         prompt_config=prompt_config,
         context_docs=context_docs,
+        files=latest_query_files,
         all_doc_useful=all_doc_useful,
         history_message=history_message,
     )

-    final_prompt_msgs = drop_messages_history_overflow(
+    final_prompt_msgs = _drop_messages_history_overflow(
         system_msg=system_message_or_none,
         system_token_count=system_tokens,
         history_msgs=history_basemessages,

View File

@@ -18,8 +18,10 @@ class WellKnownLLMProviderDescriptor(BaseModel):
 OPENAI_PROVIDER_NAME = "openai"
 OPEN_AI_MODEL_NAMES = [
     "gpt-4",
+    "gpt-4-turbo",
     "gpt-4-turbo-preview",
     "gpt-4-1106-preview",
+    "gpt-4-vision-preview",
     "gpt-4-32k",
     "gpt-4-0613",
     "gpt-4-32k-0613",

View File

@@ -2,6 +2,7 @@ from collections.abc import Callable
 from collections.abc import Iterator
 from copy import copy
 from typing import Any
+from typing import cast
 from typing import TYPE_CHECKING
 from typing import Union
@@ -24,6 +25,7 @@ from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.models import ChatMessage
+from danswer.file_store.models import InMemoryChatFile
 from danswer.llm.interfaces import LLM
 from danswer.search.models import InferenceChunk
 from danswer.utils.logger import setup_logger
@@ -85,12 +87,17 @@ def tokenizer_trim_chunks(
 def translate_danswer_msg_to_langchain(
     msg: Union[ChatMessage, "PreviousMessage"],
 ) -> BaseMessage:
+    # If the message is a `ChatMessage`, it doesn't have the downloaded files
+    # attached. Just ignore them for now
+    files = [] if isinstance(msg, ChatMessage) else msg.files
+    content = build_content_with_imgs(msg.message, files)
+
     if msg.message_type == MessageType.SYSTEM:
         raise ValueError("System messages are not currently part of history")
     if msg.message_type == MessageType.ASSISTANT:
-        return AIMessage(content=msg.message)
+        return AIMessage(content=content)
     if msg.message_type == MessageType.USER:
-        return HumanMessage(content=msg.message)
+        return HumanMessage(content=content)

     raise ValueError(f"New message type {msg.message_type} not handled")
@@ -107,6 +114,33 @@ def translate_history_to_basemessages(
     return history_basemessages, history_token_counts


+def build_content_with_imgs(
+    message: str, files: list[InMemoryChatFile]
+) -> str | list[str | dict]:  # matching Langchain's BaseMessage content type
+    if not files:
+        return message
+
+    return cast(
+        list[str | dict],
+        [
+            {
+                "type": "text",
+                "text": message,
+            },
+        ]
+        + [
+            {
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/jpeg;base64,{file.to_base64()}",
+                },
+            }
+            for file in files
+            if file.file_type == "image"
+        ],
+    )
+
+
 def dict_based_prompt_to_langchain_prompt(
     messages: list[dict[str, str]]
 ) -> list[BaseMessage]:
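
For a message with attached images, the helper above produces the list-of-parts content structure that multimodal chat models expect. An illustration of the output for one text part plus one image (base64 payload shortened):

# What build_content_with_imgs returns for a question with one image:
expected_content = [
    {"type": "text", "text": "What is in this picture?"},
    {
        "type": "image_url",
        "image_url": {"url": "data:image/jpeg;base64,iVBORw0KGgo..."},
    },
]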

View File

@@ -215,7 +215,6 @@ def stream_answer_objects(
         llm=get_llm_for_persona(persona=chat_session.persona),
         doc_relevance_list=search_pipeline.section_relevance_list,
         single_message_history=history_str,
-        timeout=timeout,
     )

     yield from answer.processed_streamed_output

View File

@@ -59,7 +59,6 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.document import get_document_cnts_for_cc_pairs
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_session
-from danswer.db.file_store import get_default_file_store
 from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
 from danswer.db.index_attempt import cancel_indexing_attempts_past_model
 from danswer.db.index_attempt import create_index_attempt
@@ -67,6 +66,7 @@ from danswer.db.index_attempt import get_index_attempts_for_cc_pair
 from danswer.db.index_attempt import get_latest_index_attempts
 from danswer.db.models import User
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.file_store.file_store import get_default_file_store
 from danswer.server.documents.models import AuthStatus
 from danswer.server.documents.models import AuthUrl
 from danswer.server.documents.models import ConnectorBase

View File

@@ -24,13 +24,13 @@ from danswer.db.engine import get_session
 from danswer.db.feedback import fetch_docs_ranked_by_boost
 from danswer.db.feedback import update_document_boost
 from danswer.db.feedback import update_document_hidden
-from danswer.db.file_store import get_default_file_store
 from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
 from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.file_store.file_store import get_default_file_store
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import test_llm
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier

View File

@@ -1,6 +1,10 @@
+import uuid
+
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
+from fastapi import Response
+from fastapi import UploadFile
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from sqlalchemy.orm import Session
@@ -28,6 +32,8 @@ from danswer.db.feedback import create_doc_retrieval_feedback
 from danswer.db.models import User
 from danswer.document_index.document_index_utils import get_both_index_names
 from danswer.document_index.factory import get_default_document_index
+from danswer.file_store.file_store import get_default_file_store
+from danswer.file_store.utils import build_chat_file_name
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
@@ -404,3 +410,50 @@ def seed_chat(
     return ChatSeedResponse(
         redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
     )
+
+
+"""File upload"""
+
+
+@router.post("/file")
+def upload_files_for_chat(
+    files: list[UploadFile],
+    db_session: Session = Depends(get_session),
+    _: User | None = Depends(current_user),
+) -> dict[str, list[uuid.UUID]]:
+    for file in files:
+        if file.content_type not in ("image/jpeg", "image/png", "image/webp"):
+            raise HTTPException(
+                status_code=400,
+                detail="Only .jpg, .jpeg, .png, and .webp files are currently supported",
+            )
+
+        if file.size and file.size > 20 * 1024 * 1024:
+            raise HTTPException(
+                status_code=400,
+                detail="File size must be less than 20MB",
+            )
+
+    file_store = get_default_file_store(db_session)
+
+    file_ids = []
+    for file in files:
+        file_id = uuid.uuid4()
+        file_name = build_chat_file_name(file_id)
+        file_store.save_file(file_name=file_name, content=file.file)
+        file_ids.append(file_id)
+
+    return {"file_ids": file_ids}
+
+
+@router.get("/file/{file_id}")
+def fetch_chat_file(
+    file_id: str,
+    db_session: Session = Depends(get_session),
+    _: User | None = Depends(current_user),
+) -> Response:
+    file_store = get_default_file_store(db_session)
+    file_io = file_store.read_file(build_chat_file_name(file_id), mode="b")
+    # NOTE: specifying "image/jpeg" here, but it still works for pngs
+    # TODO: do this properly
+    return Response(content=file_io.read(), media_type="image/jpeg")
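
A hedged client-side sketch of the two endpoints above, using `requests`. The base URL assumes a local deployment with the API mounted under `/api` (as the frontend's `buildImgUrl` suggests); authentication handling is omitted:

import requests

BASE = "http://localhost:8080/api/chat"

# upload one or more images; the endpoint returns their new UUIDs
with open("photo.png", "rb") as f:
    resp = requests.post(
        f"{BASE}/file",
        files=[("files", ("photo.png", f, "image/png"))],
    )
resp.raise_for_status()
file_ids = resp.json()["file_ids"]  # list of UUID strings

# the ids can be sent as `file_ids` on the next chat message, and each
# image can be fetched back for display in the UI
img = requests.get(f"{BASE}/file/{file_ids[0]}")
img.raise_for_status()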

View File

@@ -1,5 +1,6 @@
 from datetime import datetime
 from typing import Any
+from uuid import UUID

 from pydantic import BaseModel
 from pydantic import root_validator
@@ -9,6 +10,7 @@ from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.db.enums import ChatSessionSharedStatus
+from danswer.file_store.models import FileDescriptor
 from danswer.llm.override_models import LLMOverride
 from danswer.llm.override_models import PromptOverride
 from danswer.search.models import BaseFilters
@@ -81,6 +83,8 @@ class CreateChatMessageRequest(ChunkContext):
     parent_message_id: int | None
     # New message contents
     message: str
+    # files that we should attach to this message
+    file_ids: list[UUID]
     # If no prompt provided, uses the largest prompt of the chat session
     # but really this should be explicitly specified, only in the simplified APIs is this inferred
     # Use prompt_id 0 to use the system default prompt which is Answer-Question
@@ -171,6 +175,7 @@ class ChatMessageDetail(BaseModel):
     time_sent: datetime
     # Dict mapping citation number to db_doc_id
     citations: dict[int, int] | None
+    files: list[FileDescriptor]

     def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
         initial_dict = super().dict(*args, **kwargs)  # type: ignore
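
Putting the request model together, a message that references an uploaded image would carry a payload along these lines (only fields visible in this diff are shown; the UUID is made up):

payload = {
    "parent_message_id": None,
    "message": "What is in the attached screenshot?",
    "file_ids": ["6f1c1a0e-1b2d-4c3e-9f4a-5b6c7d8e9f0a"],
}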

web/package-lock.json (generated, 37 changes)
View File

@@ -12,6 +12,7 @@
         "@dnd-kit/modifiers": "^7.0.0",
         "@dnd-kit/sortable": "^8.0.0",
         "@phosphor-icons/react": "^2.0.8",
+        "@radix-ui/react-dialog": "^1.0.5",
         "@radix-ui/react-popover": "^1.0.7",
         "@tremor/react": "^3.9.2",
         "@types/js-cookie": "^3.0.3",
@@ -1221,6 +1222,42 @@
       }
     }
   },
+    "node_modules/@radix-ui/react-dialog": {
+      "version": "1.0.5",
+      "resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",
+      "integrity": "sha512-GjWJX/AUpB703eEBanuBnIWdIXg6NvJFCXcNlSZk4xdszCdhrJgBoUd1cGk67vFO+WdA2pfI/plOpqz/5GUP6Q==",
+      "dependencies": {
+        "@babel/runtime": "^7.13.10",
+        "@radix-ui/primitive": "1.0.1",
+        "@radix-ui/react-compose-refs": "1.0.1",
+        "@radix-ui/react-context": "1.0.1",
+        "@radix-ui/react-dismissable-layer": "1.0.5",
+        "@radix-ui/react-focus-guards": "1.0.1",
+        "@radix-ui/react-focus-scope": "1.0.4",
+        "@radix-ui/react-id": "1.0.1",
+        "@radix-ui/react-portal": "1.0.4",
+        "@radix-ui/react-presence": "1.0.1",
+        "@radix-ui/react-primitive": "1.0.3",
+        "@radix-ui/react-slot": "1.0.2",
+        "@radix-ui/react-use-controllable-state": "1.0.1",
+        "aria-hidden": "^1.1.1",
+        "react-remove-scroll": "2.5.5"
+      },
+      "peerDependencies": {
+        "@types/react": "*",
+        "@types/react-dom": "*",
+        "react": "^16.8 || ^17.0 || ^18.0",
+        "react-dom": "^16.8 || ^17.0 || ^18.0"
+      },
+      "peerDependenciesMeta": {
+        "@types/react": {
+          "optional": true
+        },
+        "@types/react-dom": {
+          "optional": true
+        }
+      }
+    },
     "node_modules/@radix-ui/react-dismissable-layer": {
       "version": "1.0.5",
       "resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz",

View File

@@ -13,6 +13,7 @@
     "@dnd-kit/modifiers": "^7.0.0",
     "@dnd-kit/sortable": "^8.0.0",
     "@phosphor-icons/react": "^2.0.8",
+    "@radix-ui/react-dialog": "^1.0.5",
     "@radix-ui/react-popover": "^1.0.7",
     "@tremor/react": "^3.9.2",
     "@types/js-cookie": "^3.0.3",

View File

@@ -27,3 +27,11 @@ export interface FullLLMProvider extends LLMProvider {
   is_default_provider: boolean | null;
   model_names: string[];
 }
+
+export interface LLMProviderDescriptor {
+  name: string;
+  model_names: string[];
+  default_model_name: string;
+  fast_default_model_name: string | null;
+  is_default_provider: boolean | null;
+}

View File

@@ -7,6 +7,7 @@ import {
   ChatSession,
   ChatSessionSharedStatus,
   DocumentsResponse,
+  FileDescriptor,
   Message,
   RetrievalType,
   StreamingError,
@@ -30,6 +31,7 @@ import {
   personaIncludesRetrieval,
   processRawChatHistory,
   sendMessage,
+  uploadFilesForChat,
 } from "./lib";
 import { useContext, useEffect, useRef, useState } from "react";
 import { usePopup } from "@/components/admin/connectors/Popup";
@@ -56,16 +58,21 @@ import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces";
 import { buildFilters } from "@/lib/search/utils";
 import { Tabs } from "./sessionSidebar/constants";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
+import Dropzone from "react-dropzone";
+import { LLMProviderDescriptor } from "../admin/models/llm/interfaces";
+import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
+import { InputBarPreviewImage } from "./images/InputBarPreviewImage";

 const MAX_INPUT_HEIGHT = 200;

-export function ChatLayout({
+export function ChatPage({
   user,
   chatSessions,
   availableSources,
   availableDocumentSets,
   availablePersonas,
   availableTags,
+  llmProviders,
   defaultSelectedPersonaId,
   documentSidebarInitialWidth,
   defaultSidebarTab,
@@ -76,6 +83,7 @@ export function ChatLayout({
   availableDocumentSets: DocumentSet[];
   availablePersonas: Persona[];
   availableTags: Tag[];
+  llmProviders: LLMProviderDescriptor[];
   defaultSelectedPersonaId?: number; // what persona to default to
   documentSidebarInitialWidth?: number;
   defaultSidebarTab?: Tabs;
@@ -123,6 +131,9 @@ export function ChatLayout({
     filterManager.setSelectedSources([]);
     filterManager.setSelectedTags([]);
     filterManager.setTimeRange(null);
+
+    // remove uploaded files
+    setCurrentMessageFileIds([]);

     if (isStreaming) {
       setIsCancelled(true);
     }
@@ -209,6 +220,11 @@ export function ChatLayout({
   const [messageHistory, setMessageHistory] = useState<Message[]>([]);
   const [isStreaming, setIsStreaming] = useState(false);

+  // uploaded files
+  const [currentMessageFileIds, setCurrentMessageFileIds] = useState<string[]>(
+    []
+  );
+
   // for document display
   // NOTE: -1 is a special designation that means the latest AI message
   const [selectedMessageForDocDisplay, setSelectedMessageForDocDisplay] =
@@ -404,15 +420,21 @@ export function ChatLayout({
       messageToResendIndex !== null
         ? messageHistory.slice(0, messageToResendIndex)
         : messageHistory;
+    const currFiles = currentMessageFileIds.map((id) => ({
+      id,
+      type: "image",
+    })) as FileDescriptor[];
     setMessageHistory([
       ...currMessageHistory,
       {
         messageId: 0,
         message: currMessage,
         type: "user",
+        files: currFiles,
       },
     ]);
     setMessage("");
+    setCurrentMessageFileIds([]);

     setIsStreaming(true);
     let answer = "";
@@ -429,6 +451,7 @@ export function ChatLayout({
       getLastSuccessfulMessageId(currMessageHistory);
     for await (const packetBunch of sendMessage({
       message: currMessage,
+      fileIds: currentMessageFileIds,
       parentMessageId: lastSuccessfulMessageId,
       chatSessionId: currChatSessionId,
       promptId: livePersona?.prompts[0]?.id || 0,
@@ -479,6 +502,7 @@ export function ChatLayout({
           messageId: finalMessage?.parent_message || null,
           message: currMessage,
           type: "user",
+          files: currFiles,
         },
         {
           messageId: finalMessage?.message_id || null,
@@ -488,6 +512,7 @@ export function ChatLayout({
           query: finalMessage?.rephrased_query || query,
           documents: finalMessage?.context_docs?.top_documents || documents,
           citations: finalMessage?.citations || {},
+          files: finalMessage?.files || [],
         },
       ]);
       if (isCancelledRef.current) {
@@ -503,11 +528,13 @@ export function ChatLayout({
           messageId: null,
           message: currMessage,
           type: "user",
+          files: currFiles,
         },
         {
           messageId: null,
           message: errorMsg,
           type: "error",
+          files: [],
         },
       ]);
     }
@@ -570,7 +597,10 @@ export function ChatLayout({
   };

   const onPersonaChange = (persona: Persona | null) => {
-    if (persona) {
+    if (persona && persona.id !== livePersona.id) {
+      // remove uploaded files
+      setCurrentMessageFileIds([]);
+
       setSelectedPersona(persona);
       textareaRef.current?.focus();
       router.push(buildChatUrl(searchParams, null, persona.id));
@ -638,220 +668,256 @@ export function ChatLayout({
)} )}
{documentSidebarInitialWidth !== undefined ? ( {documentSidebarInitialWidth !== undefined ? (
<> <Dropzone
<div onDrop={(acceptedFiles) => {
className={`w-full sm:relative h-screen ${ uploadFilesForChat(acceptedFiles).then(([fileIds, error]) => {
retrievalDisabled ? "pb-[111px]" : "pb-[140px]" if (error) {
}`} setPopup({
> type: "error",
<div message: error,
className={`w-full h-full ${HEADER_PADDING} flex flex-col overflow-y-auto overflow-x-hidden relative`} });
ref={scrollableDivRef} } else {
> const newFileIds = [...currentMessageFileIds, ...fileIds];
{livePersona && ( setCurrentMessageFileIds(newFileIds);
<div className="sticky top-0 left-80 z-10 w-full bg-background/90 flex"> }
<div className="ml-2 p-1 rounded mt-2 w-fit"> });
<ChatPersonaSelector }}
personas={availablePersonas} noClick
selectedPersonaId={livePersona.id} disabled={
onPersonaChange={onPersonaChange} !checkLLMSupportsImageInput(
/> ...getFinalLLM(llmProviders, livePersona)
</div> )
}
onDragLeave={() => console.log("buh")}
onDragEnter={() => console.log("floppa")}
>
{({ getRootProps }) => (
<>
<div
className={`w-full sm:relative h-screen ${
retrievalDisabled ? "pb-[111px]" : "pb-[140px]"
}`}
{...getRootProps()}
>
{/* <input {...getInputProps()} /> */}
<div
className={`w-full h-full ${HEADER_PADDING} flex flex-col overflow-y-auto overflow-x-hidden relative`}
ref={scrollableDivRef}
>
{livePersona && (
<div className="sticky top-0 left-80 z-10 w-full bg-background/90 flex">
<div className="ml-2 p-1 rounded mt-2 w-fit">
<ChatPersonaSelector
personas={availablePersonas}
selectedPersonaId={livePersona.id}
onPersonaChange={onPersonaChange}
/>
</div>
{chatSessionId !== null && ( {chatSessionId !== null && (
<div <div
onClick={() => setSharingModalVisible(true)} onClick={() => setSharingModalVisible(true)}
className="ml-auto mr-6 my-auto border-border border p-2 rounded cursor-pointer hover:bg-hover-light" className="ml-auto mr-6 my-auto border-border border p-2 rounded cursor-pointer hover:bg-hover-light"
> >
<FiShare2 /> <FiShare2 />
</div>
)}
</div> </div>
)} )}
</div>
)}
{messageHistory.length === 0 && {messageHistory.length === 0 &&
!isFetchingChatMessages && !isFetchingChatMessages &&
!isStreaming && ( !isStreaming && (
<ChatIntro <ChatIntro
availableSources={finalAvailableSources} availableSources={finalAvailableSources}
availablePersonas={availablePersonas} availablePersonas={availablePersonas}
selectedPersona={selectedPersona} selectedPersona={selectedPersona}
handlePersonaSelect={(persona) => { handlePersonaSelect={(persona) => {
setSelectedPersona(persona); setSelectedPersona(persona);
textareaRef.current?.focus(); textareaRef.current?.focus();
router.push( router.push(
buildChatUrl(searchParams, null, persona.id) buildChatUrl(searchParams, null, persona.id)
); );
}} }}
/> />
)} )}
<div <div
className={ className={
"mt-4 pt-12 sm:pt-0 mx-8" + "mt-4 pt-12 sm:pt-0 mx-8" +
(hasPerformedInitialScroll ? "" : " invisible") (hasPerformedInitialScroll ? "" : " invisible")
} }
> >
{messageHistory.map((message, i) => { {messageHistory.map((message, i) => {
if (message.type === "user") { if (message.type === "user") {
return ( return (
<div key={i}> <div key={i}>
<HumanMessage content={message.message} /> <HumanMessage
</div> content={message.message}
); files={message.files}
} else if (message.type === "assistant") { />
const isShowingRetrieved = </div>
(selectedMessageForDocDisplay !== null && );
selectedMessageForDocDisplay === } else if (message.type === "assistant") {
message.messageId) || const isShowingRetrieved =
(selectedMessageForDocDisplay === -1 && (selectedMessageForDocDisplay !== null &&
i === messageHistory.length - 1); selectedMessageForDocDisplay ===
const previousMessage = message.messageId) ||
i !== 0 ? messageHistory[i - 1] : null; (selectedMessageForDocDisplay === -1 &&
return ( i === messageHistory.length - 1);
<div key={i}> const previousMessage =
<AIMessage i !== 0 ? messageHistory[i - 1] : null;
messageId={message.messageId} return (
content={message.message} <div key={i}>
query={messageHistory[i]?.query || undefined} <AIMessage
personaName={livePersona.name} messageId={message.messageId}
citedDocuments={getCitedDocumentsFromMessage( content={message.message}
message query={messageHistory[i]?.query || undefined}
)} personaName={livePersona.name}
isComplete={ citedDocuments={getCitedDocumentsFromMessage(
i !== messageHistory.length - 1 || !isStreaming message
} )}
hasDocs={ isComplete={
(message.documents && i !== messageHistory.length - 1 ||
message.documents.length > 0) === true !isStreaming
} }
handleFeedback={ hasDocs={
i === messageHistory.length - 1 && isStreaming (message.documents &&
? undefined message.documents.length > 0) === true
: (feedbackType) => }
setCurrentFeedback([ handleFeedback={
feedbackType, i === messageHistory.length - 1 &&
message.messageId as number, isStreaming
]) ? undefined
} : (feedbackType) =>
handleSearchQueryEdit={ setCurrentFeedback([
i === messageHistory.length - 1 && !isStreaming feedbackType,
? (newQuery) => { message.messageId as number,
if (!previousMessage) { ])
setPopup({ }
type: "error", handleSearchQueryEdit={
message: i === messageHistory.length - 1 &&
"Cannot edit query of first message - please refresh the page and try again.", !isStreaming
}); ? (newQuery) => {
return; if (!previousMessage) {
} setPopup({
type: "error",
if (previousMessage.messageId === null) { message:
setPopup({ "Cannot edit query of first message - please refresh the page and try again.",
type: "error", });
message: return;
"Cannot edit query of a pending message - please wait a few seconds and try again.", }
});
return; if (
previousMessage.messageId === null
) {
setPopup({
type: "error",
message:
"Cannot edit query of a pending message - please wait a few seconds and try again.",
});
return;
}
onSubmit({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
});
}
: undefined
}
isCurrentlyShowingRetrieved={
isShowingRetrieved
}
handleShowRetrieved={(messageNumber) => {
if (isShowingRetrieved) {
setSelectedMessageForDocDisplay(null);
} else {
if (messageNumber !== null) {
setSelectedMessageForDocDisplay(
messageNumber
);
} else {
setSelectedMessageForDocDisplay(-1);
} }
}
}}
handleForceSearch={() => {
if (
previousMessage &&
previousMessage.messageId
) {
onSubmit({ onSubmit({
messageIdToResend: messageIdToResend:
previousMessage.messageId, previousMessage.messageId,
queryOverride: newQuery, forceSearch: true,
});
} else {
setPopup({
type: "error",
message:
"Failed to force search - please refresh the page and try again.",
}); });
} }
: undefined }}
} retrievalDisabled={retrievalDisabled}
isCurrentlyShowingRetrieved={isShowingRetrieved}
handleShowRetrieved={(messageNumber) => {
if (isShowingRetrieved) {
setSelectedMessageForDocDisplay(null);
} else {
if (messageNumber !== null) {
setSelectedMessageForDocDisplay(
messageNumber
);
} else {
setSelectedMessageForDocDisplay(-1);
}
}
}}
handleForceSearch={() => {
if (
previousMessage &&
previousMessage.messageId
) {
onSubmit({
messageIdToResend:
previousMessage.messageId,
forceSearch: true,
});
} else {
setPopup({
type: "error",
message:
"Failed to force search - please refresh the page and try again.",
});
}
}}
retrievalDisabled={retrievalDisabled}
/>
</div>
);
} else {
return (
<div key={i}>
<AIMessage
messageId={message.messageId}
personaName={livePersona.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
</p>
}
/>
</div>
);
}
})}
{isStreaming &&
messageHistory.length &&
messageHistory[messageHistory.length - 1].type ===
"user" && (
<div key={messageHistory.length}>
<AIMessage
messageId={null}
personaName={livePersona.name}
content={
<div className="text-sm my-auto">
<ThreeDots
height="30"
width="50"
color="#3b82f6"
ariaLabel="grid-loading"
radius="12.5"
wrapperStyle={{}}
wrapperClass=""
visible={true}
/> />
</div> </div>
} );
/> } else {
</div> return (
)} <div key={i}>
<AIMessage
messageId={message.messageId}
personaName={livePersona.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
</p>
}
/>
</div>
);
}
})}
{/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/} {isStreaming &&
<div className={`min-h-[30px] w-full`}></div> messageHistory.length &&
messageHistory[messageHistory.length - 1].type ===
"user" && (
<div key={messageHistory.length}>
<AIMessage
messageId={null}
personaName={livePersona.name}
content={
<div className="text-sm my-auto">
<ThreeDots
height="30"
width="50"
color="#3b82f6"
ariaLabel="grid-loading"
radius="12.5"
wrapperStyle={{}}
wrapperClass=""
visible={true}
/>
</div>
}
/>
</div>
)}
{livePersona && {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
livePersona.starter_messages && <div className={`min-h-[30px] w-full`}></div>
livePersona.starter_messages.length > 0 &&
selectedPersona && {livePersona &&
messageHistory.length === 0 && livePersona.starter_messages &&
!isFetchingChatMessages && ( livePersona.starter_messages.length > 0 &&
<div selectedPersona &&
className={` messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
className={`
mx-auto mx-auto
px-4 px-4
w-searchbar-xs w-searchbar-xs
@ -864,156 +930,193 @@ export function ChatLayout({
mt-4 mt-4
md:grid-cols-2 md:grid-cols-2
mb-6`} mb-6`}
> >
{livePersona.starter_messages.map( {livePersona.starter_messages.map(
(starterMessage, i) => ( (starterMessage, i) => (
<div key={i} className="w-full"> <div key={i} className="w-full">
<StarterMessage <StarterMessage
starterMessage={starterMessage} starterMessage={starterMessage}
onClick={() => onClick={() =>
onSubmit({ onSubmit({
messageOverride: starterMessage.message, messageOverride:
}) starterMessage.message,
} })
/> }
</div> />
) </div>
)
)}
</div>
)} )}
</div>
)}
<div ref={endDivRef} /> <div ref={endDivRef} />
</div>
</div>
<div className="absolute bottom-0 z-10 w-full bg-background border-t border-border">
<div className="w-full pb-4 pt-2">
{!retrievalDisabled && (
<div className="flex">
<div className="w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto px-4 pt-1 flex">
{selectedDocuments.length > 0 ? (
<SelectedDocuments
selectedDocuments={selectedDocuments}
/>
) : (
<ChatFilters
{...filterManager}
existingSources={finalAvailableSources}
availableDocumentSets={finalAvailableDocumentSets}
availableTags={availableTags}
/>
)}
</div>
</div> </div>
)} </div>
<div className="flex justify-center py-2 max-w-screen-lg mx-auto mb-2"> <div className="absolute bottom-0 z-10 w-full bg-background border-t border-border">
<div className="w-full shrink relative px-4 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto"> <div className="w-full pb-4 pt-2">
<textarea {!retrievalDisabled && (
ref={textareaRef} <div className="flex">
autoFocus <div className="w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto px-4 pt-1 flex">
className={` {selectedDocuments.length > 0 ? (
opacity-100 <SelectedDocuments
w-full selectedDocuments={selectedDocuments}
shrink />
border ) : (
border-border <ChatFilters
rounded-lg {...filterManager}
outline-none existingSources={finalAvailableSources}
placeholder-gray-400 availableDocumentSets={
pl-4 finalAvailableDocumentSets
pr-12 }
py-4 availableTags={availableTags}
overflow-hidden />
h-14 )}
${ </div>
(textareaRef?.current?.scrollHeight || 0) > </div>
MAX_INPUT_HEIGHT )}
? "overflow-y-auto"
: "" <div className="flex justify-center py-2 max-w-screen-lg mx-auto mb-2">
} <div className="w-full shrink relative px-4 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto">
whitespace-normal <div
break-word className={`
overscroll-contain opacity-100
resize-none w-full
`} h-fit
style={{ scrollbarWidth: "thin" }} flex
role="textarea" flex-col
aria-multiline border
placeholder="Ask me anything..." border-border
value={message} rounded-lg
onChange={(e) => setMessage(e.target.value)} [&:has(textarea:focus)]::ring-1
onKeyDown={(event) => { [&:has(textarea:focus)]::ring-black
if ( `}
event.key === "Enter" && >
!event.shiftKey && {currentMessageFileIds.length > 0 && (
message && <div className="flex flex-wrap gap-y-2 px-1">
!isStreaming {currentMessageFileIds.map((fileId) => (
) { <div key={fileId} className="py-1">
onSubmit(); <InputBarPreviewImage
event.preventDefault(); fileId={fileId}
} onDelete={() => {
}} setCurrentMessageFileIds(
suppressContentEditableWarning={true} currentMessageFileIds.filter(
/> (id) => id !== fileId
<div className="absolute bottom-4 right-10"> )
<div );
className={"cursor-pointer"} }}
onClick={() => { />
if (!isStreaming) { </div>
if (message) { ))}
onSubmit(); </div>
} )}
} else { <textarea
setIsCancelled(true); ref={textareaRef}
} className={`
}} m-0
> w-full
{isStreaming ? ( shrink
<FiStopCircle resize-none
size={18} border-0
className={ bg-transparent
"text-emphasis w-9 h-9 p-2 rounded-lg hover:bg-hover" ${
} (textareaRef?.current?.scrollHeight || 0) >
MAX_INPUT_HEIGHT
? "overflow-y-auto"
: ""
}
whitespace-normal
break-word
overscroll-contain
outline-none
placeholder-gray-400
overflow-hidden
resize-none
pl-4
pr-12
py-4
h-14`}
autoFocus
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder="Ask me anything..."
value={message}
onChange={(e) => setMessage(e.target.value)}
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!event.shiftKey &&
message &&
!isStreaming
) {
onSubmit();
event.preventDefault();
}
}}
suppressContentEditableWarning={true}
/> />
) : ( </div>
<FiSend <div className="absolute bottom-2.5 right-10">
size={18} <div
className={ className={"cursor-pointer"}
"text-emphasis w-9 h-9 p-2 rounded-lg " + onClick={() => {
(message ? "bg-blue-200" : "") if (!isStreaming) {
} if (message) {
/> onSubmit();
)} }
} else {
setIsCancelled(true);
}
}}
>
{isStreaming ? (
<FiStopCircle
size={18}
className={
"text-emphasis w-9 h-9 p-2 rounded-lg hover:bg-hover"
}
/>
) : (
<FiSend
size={18}
className={
"text-emphasis w-9 h-9 p-2 rounded-lg " +
(message ? "bg-blue-200" : "")
}
/>
)}
</div>
</div>
</div> </div>
</div> </div>
</div> </div>
</div> </div>
</div> </div>
</div>
</div>
{!retrievalDisabled ? ( {!retrievalDisabled ? (
<ResizableSection <ResizableSection
intialWidth={documentSidebarInitialWidth} intialWidth={documentSidebarInitialWidth as number}
minWidth={400} minWidth={400}
maxWidth={maxDocumentSidebarWidth || undefined} maxWidth={maxDocumentSidebarWidth || undefined}
> >
<DocumentSidebar <DocumentSidebar
selectedMessage={aiMessage} selectedMessage={aiMessage}
selectedDocuments={selectedDocuments} selectedDocuments={selectedDocuments}
toggleDocumentSelection={toggleDocumentSelection} toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments} clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens} selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens} maxTokens={maxTokens}
isLoading={isFetchingChatMessages} isLoading={isFetchingChatMessages}
/> />
</ResizableSection> </ResizableSection>
) : // Another option is to use a div with the width set to the initial width, so that the ) : // Another option is to use a div with the width set to the initial width, so that the
// chat section appears in the same place as before // chat section appears in the same place as before
// <div style={documentSidebarInitialWidth ? {width: documentSidebarInitialWidth} : {}}></div> // <div style={documentSidebarInitialWidth ? {width: documentSidebarInitialWidth} : {}}></div>
null} null}
</> </>
)}
</Dropzone>
) : ( ) : (
<div className="mx-auto h-full flex flex-col"> <div className="mx-auto h-full flex flex-col">
<div className="my-auto"> <div className="my-auto">
View File
@ -0,0 +1,50 @@
"use client";
import { useEffect } from "react";
import { buildImgUrl } from "./utils";
import * as Dialog from "@radix-ui/react-dialog";
export function FullImageModal({
fileId,
open,
onOpenChange,
}: {
fileId: string;
open: boolean;
onOpenChange: (open: boolean) => void;
}) {
// pre-fetch image
useEffect(() => {
const img = new Image();
img.src = buildImgUrl(fileId);
}, [fileId]);
return (
<Dialog.Root open={open} onOpenChange={onOpenChange}>
<Dialog.Portal>
<Dialog.Overlay className="fixed inset-0 bg-black bg-opacity-80 z-50" />
<Dialog.Content
className={`fixed
inset-0
flex
items-center
justify-center
p-4
z-[100]
max-w-screen-lg
h-fit
top-1/2
left-1/2
-translate-y-2/4
-translate-x-2/4`}
>
<img
src={buildImgUrl(fileId)}
alt="Uploaded image"
className="max-w-full max-h-full"
/>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>
);
}
View File
@ -0,0 +1,33 @@
"use client";
import { useState } from "react";
import { FullImageModal } from "./FullImageModal";
import { buildImgUrl } from "./utils";
export function InMessageImage({ fileId }: { fileId: string }) {
const [fullImageShowing, setFullImageShowing] = useState(false);
return (
<>
<FullImageModal
fileId={fileId}
open={fullImageShowing}
onOpenChange={(open) => setFullImageShowing(open)}
/>
<img
className={`
max-w-lg
rounded-lg
bg-transparent
cursor-pointer
transition-opacity
duration-300
opacity-100`}
onClick={() => setFullImageShowing(true)}
src={buildImgUrl(fileId)}
loading="lazy"
/>
</>
);
}
View File
@ -0,0 +1,46 @@
"use client";
import { useState } from "react";
import { FiX } from "react-icons/fi";
import { buildImgUrl } from "./utils";
import { FullImageModal } from "./FullImageModal";
export function InputBarPreviewImage({
fileId,
onDelete,
}: {
fileId: string;
onDelete: () => void;
}) {
const [isHovered, setIsHovered] = useState(false);
const [fullImageShowing, setFullImageShowing] = useState(false);
return (
<>
<FullImageModal
fileId={fileId}
open={fullImageShowing}
onOpenChange={(open) => setFullImageShowing(open)}
/>
<div
className="p-1 relative"
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
>
{isHovered && (
<button
onClick={onDelete}
className="absolute top-0 right-0 cursor-pointer border-none bg-hover p-1 rounded-full"
>
<FiX />
</button>
)}
<img
onClick={() => setFullImageShowing(true)}
className="h-16 w-16 object-cover rounded-lg bg-background cursor-pointer"
src={buildImgUrl(fileId)}
/>
</div>
</>
);
}
View File
@ -0,0 +1,3 @@
export function buildImgUrl(fileId: string) {
  return `/api/chat/file/${fileId}`;
}
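The frontend never holds image bytes; a bare file id is enough to render, since this helper just points at the backend file endpoint. For illustration (the id is invented):

// "3fa85f64" is a hypothetical file id returned by the upload endpoint.
buildImgUrl("3fa85f64"); // => "/api/chat/file/3fa85f64"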
View File
@ -20,6 +20,11 @@ export interface RetrievalDetails {
type CitationMap = { [key: string]: number };

export interface FileDescriptor {
  id: string;
  type: "image";
}

export interface ChatSession {
  id: number;
  name: string;
@ -36,6 +41,7 @@ export interface Message {
  query?: string | null;
  documents?: DanswerDocument[] | null;
  citations?: CitationMap;
  files: FileDescriptor[];
}

export interface BackendChatSession {
@ -58,6 +64,7 @@ export interface BackendMessage {
  message_type: "user" | "assistant" | "system";
  time_sent: string;
  citations: CitationMap;
  files: FileDescriptor[];
}
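A message therefore carries a flat list of id/type descriptors. A hypothetical value (the id is invented for the example):

// Illustrative only.
const attached: FileDescriptor[] = [{ id: "3fa85f64", type: "image" }];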

export interface DocumentsResponse {
View File
@ -47,6 +47,7 @@ export async function createChatSession(
export async function* sendMessage({
  message,
  fileIds,
  parentMessageId,
  chatSessionId,
  promptId,
@ -60,6 +61,7 @@ export async function* sendMessage({
  useExistingUserMessage,
}: {
  message: string;
  fileIds: string[];
  parentMessageId: number | null;
  chatSessionId: number;
  promptId: number | null | undefined;
@ -89,6 +91,7 @@ export async function* sendMessage({
    message: message,
    prompt_id: promptId,
    search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
    file_ids: fileIds,
    retrieval_options: !documentsAreSelected
      ? {
          run_search:
@ -364,6 +367,7 @@ export function processRawChatHistory(rawMessages: BackendMessage[]) {
    messageId: messageInfo.message_id,
    message: messageInfo.message,
    type: messageInfo.message_type as "user" | "assistant",
    files: messageInfo.files,
    // only include these fields if this is an assistant message so that
    // this is identical to what is computed at streaming time
    ...(messageInfo.message_type === "assistant"
@ -419,3 +423,23 @@ export function buildChatUrl(
  return "/chat";
}
export async function uploadFilesForChat(
  files: File[]
): Promise<[string[], string | null]> {
  const formData = new FormData();
  files.forEach((file) => {
    formData.append("files", file);
  });

  const response = await fetch("/api/chat/file", {
    method: "POST",
    body: formData,
  });
  if (!response.ok) {
    return [[], `Failed to upload files - ${(await response.json()).detail}`];
  }
  const responseJson = await response.json();

  return [responseJson.file_ids as string[], null];
}
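A hedged sketch of calling this from a drop or paste handler; `setPopup` and `setCurrentMessageFileIds` are assumed to exist in the surrounding component and are not part of this function's contract:

// Sketch under assumptions: surface the error string if the upload failed,
// otherwise stage the returned ids so the input bar can preview them.
async function handleImageUpload(acceptedFiles: File[]) {
  const [fileIds, error] = await uploadFilesForChat(acceptedFiles);
  if (error) {
    setPopup({ type: "error", message: error }); // assumed popup helper
    return;
  }
  setCurrentMessageFileIds((ids) => [...ids, ...fileIds]);
}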
View File
@ -16,6 +16,8 @@ import { ThreeDots } from "react-loader-spinner";
import { SkippedSearch } from "./SkippedSearch"; import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm"; import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton"; import { CopyButton } from "@/components/CopyButton";
import { FileDescriptor } from "../interfaces";
import { InMessageImage } from "../images/InMessageImage";
export const Hoverable: React.FC<{ export const Hoverable: React.FC<{
children: JSX.Element; children: JSX.Element;
@ -219,8 +221,10 @@ export const AIMessage = ({
export const HumanMessage = ({ export const HumanMessage = ({
content, content,
files,
}: { }: {
content: string | JSX.Element; content: string | JSX.Element;
files?: FileDescriptor[];
}) => { }) => {
return ( return (
<div className="py-5 px-5 flex -mr-6 w-full"> <div className="py-5 px-5 flex -mr-6 w-full">
@ -237,6 +241,16 @@ export const HumanMessage = ({
</div> </div>
<div className="mx-auto mt-1 ml-8 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar-default flex flex-wrap"> <div className="mx-auto mt-1 ml-8 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar-default flex flex-wrap">
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words"> <div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words">
{files && files.length > 0 && (
<div className="mt-2 mb-4">
<div className="flex flex-wrap gap-2">
{files.map((file) => {
return <InMessageImage key={file.id} fileId={file.id} />;
})}
</div>
</div>
)}
{typeof content === "string" ? ( {typeof content === "string" ? (
<ReactMarkdown <ReactMarkdown
className="prose max-w-full" className="prose max-w-full"
View File
@ -24,11 +24,13 @@ import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { cookies } from "next/headers"; import { cookies } from "next/headers";
import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants"; import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants";
import { personaComparator } from "../admin/assistants/lib"; import { personaComparator } from "../admin/assistants/lib";
import { ChatLayout } from "./ChatPage"; import { ChatPage } from "./ChatPage";
import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels"; import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels";
import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal"; import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal";
import { Settings } from "../admin/settings/interfaces"; import { Settings } from "../admin/settings/interfaces";
import { SIDEBAR_TAB_COOKIE, Tabs } from "./sessionSidebar/constants"; import { SIDEBAR_TAB_COOKIE, Tabs } from "./sessionSidebar/constants";
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
import { LLMProviderDescriptor } from "../admin/models/llm/interfaces";
export default async function Page({ export default async function Page({
searchParams, searchParams,
@ -45,6 +47,7 @@ export default async function Page({
fetchSS("/persona?include_default=true"), fetchSS("/persona?include_default=true"),
fetchSS("/chat/get-user-chat-sessions"), fetchSS("/chat/get-user-chat-sessions"),
fetchSS("/query/valid-tags"), fetchSS("/query/valid-tags"),
fetchLLMProvidersSS(),
]; ];
// catch cases where the backend is completely unreachable here // catch cases where the backend is completely unreachable here
@ -56,8 +59,9 @@ export default async function Page({
| AuthTypeMetadata | AuthTypeMetadata
| FullEmbeddingModelResponse | FullEmbeddingModelResponse
| Settings | Settings
| LLMProviderDescriptor[]
| null | null
)[] = [null, null, null, null, null, null, null, null, null]; )[] = [null, null, null, null, null, null, null, null, null, null];
try { try {
results = await Promise.all(tasks); results = await Promise.all(tasks);
} catch (e) { } catch (e) {
@ -70,6 +74,7 @@ export default async function Page({
const personasResponse = results[4] as Response | null; const personasResponse = results[4] as Response | null;
const chatSessionsResponse = results[5] as Response | null; const chatSessionsResponse = results[5] as Response | null;
const tagsResponse = results[6] as Response | null; const tagsResponse = results[6] as Response | null;
const llmProviders = (results[7] || []) as LLMProviderDescriptor[];
const authDisabled = authTypeMetadata?.authType === "disabled"; const authDisabled = authTypeMetadata?.authType === "disabled";
if (!authDisabled && !user) { if (!authDisabled && !user) {
@ -177,13 +182,14 @@ export default async function Page({
<NoCompleteSourcesModal ccPairs={ccPairs} /> <NoCompleteSourcesModal ccPairs={ccPairs} />
)} )}
<ChatLayout <ChatPage
user={user} user={user}
chatSessions={chatSessions} chatSessions={chatSessions}
availableSources={availableSources} availableSources={availableSources}
availableDocumentSets={documentSets} availableDocumentSets={documentSets}
availablePersonas={personas} availablePersonas={personas}
availableTags={tags} availableTags={tags}
llmProviders={llmProviders}
defaultSelectedPersonaId={defaultPersonaId} defaultSelectedPersonaId={defaultPersonaId}
documentSidebarInitialWidth={finalDocumentSidebarInitialWidth} documentSidebarInitialWidth={finalDocumentSidebarInitialWidth}
defaultSidebarTab={defaultSidebarTab} defaultSidebarTab={defaultSidebarTab}
View File
@ -0,0 +1,10 @@
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { fetchSS } from "../utilsSS";

export async function fetchLLMProvidersSS() {
  const response = await fetchSS("/llm/provider");
  if (response.ok) {
    return (await response.json()) as LLMProviderDescriptor[];
  }
  return [];
}
web/src/lib/llm/utils.ts
View File
@ -0,0 +1,33 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";

export function getFinalLLM(
  llmProviders: LLMProviderDescriptor[],
  persona: Persona | null
): [string, string] {
  const defaultProvider = llmProviders.find(
    (llmProvider) => llmProvider.is_default_provider
  );

  let provider = defaultProvider?.name || "";
  let model = defaultProvider?.default_model_name || "";
  if (persona) {
    provider = persona.llm_model_provider_override || provider;
    model = persona.llm_model_version_override || model;
  }

  return [provider, model];
}

const MODELS_SUPPORTING_IMAGES = [
  ["openai", "gpt-4-vision-preview"],
  ["openai", "gpt-4-turbo"],
  ["openai", "gpt-4-1106-vision-preview"],
];

export function checkLLMSupportsImageInput(provider: string, model: string) {
  return MODELS_SUPPORTING_IMAGES.some(
    ([p, m]) => p === provider && m === model
  );
}
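Presumably these two helpers gate the image-upload UI on the effective model. A sketch of that check, assuming `llmProviders` and `livePersona` are in scope in the calling component:

import { getFinalLLM, checkLLMSupportsImageInput } from "@/lib/llm/utils";

// Hedged sketch: only offer image upload when the resolved provider/model
// pair appears in MODELS_SUPPORTING_IMAGES.
const [provider, model] = getFinalLLM(llmProviders, livePersona);
const showImageUpload = checkLLMSupportsImageInput(provider, model);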