Add image upload capabilities

This commit is contained in:
Weves 2024-04-29 01:18:47 -07:00 committed by Chris Weaver
parent 350e548b2d
commit 5b93e786ad
33 changed files with 992 additions and 363 deletions

View File

@ -0,0 +1,27 @@
"""Add files to ChatMessage
Revision ID: ef7da92f7213
Revises: 401c1ac29467
Create Date: 2024-04-28 16:59:33.199153
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "ef7da92f7213"
down_revision = "401c1ac29467"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "files")

View File

@ -30,6 +30,8 @@ from danswer.db.engine import get_session_context_manager
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.utils import load_all_chat_files
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
@ -174,6 +176,10 @@ def stream_chat_message_objects(
message=message_text,
token_count=len(llm_tokenizer_encode_func(message_text)),
message_type=MessageType.USER,
files=[
{"id": str(file_id), "type": ChatFileType.IMAGE}
for file_id in new_msg_req.file_ids
],
db_session=db_session,
commit=False,
)
@ -202,9 +208,20 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(history_msgs, new_msg_req.file_ids, db_session)
latest_query_files = [
file for file in files if file.file_id in new_msg_req.file_ids
]
run_search = False
# Retrieval options are only None if reference_doc_ids are provided
if retrieval_options is not None and persona.num_chunks != 0:
# Also don't perform search if the user uploaded at least one file - just use the files
if (
retrieval_options is not None
and persona.num_chunks != 0
and not new_msg_req.file_ids
):
if retrieval_options.run_search == OptionalSearchSetting.ALWAYS:
run_search = True
elif retrieval_options.run_search == OptionalSearchSetting.NEVER:
@ -360,6 +377,7 @@ def stream_chat_message_objects(
answer = Answer(
question=final_msg.message,
docs=llm_docs,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=reference_db_search_docs is not None
@ -380,7 +398,7 @@ def stream_chat_message_objects(
),
doc_relevance_list=llm_relevance_list,
message_history=[
PreviousMessage.from_chat_message(msg) for msg in history_msgs
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
)
# generator will not include quotes, so we can cast

View File

@ -23,7 +23,7 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@ -17,7 +17,7 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@ -30,6 +30,7 @@ from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
@ -256,6 +257,7 @@ def create_new_chat_message(
token_count: int,
message_type: MessageType,
db_session: Session,
files: list[FileDescriptor] | None = None,
rephrased_query: str | None = None,
error: str | None = None,
reference_docs: list[DBSearchDoc] | None = None,
@ -273,6 +275,7 @@ def create_new_chat_message(
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
error=error,
)
@ -819,6 +822,7 @@ def translate_db_message_to_chat_message_detail(
message_type=chat_message.message_type,
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
)
return chat_msg_detail

View File

@ -42,6 +42,7 @@ from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
from danswer.db.pydantic_type import PydanticType
from danswer.dynamic_configs.interface import JSON_ro
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
@ -629,6 +630,11 @@ class ChatMessage(Base):
)
# Maps the citation numbers to a SearchDoc id
citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True)
# files associated with this message (e.g. images uploaded by the user that the
# user is asking a question of)
files: Mapped[list[FileDescriptor] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Only applies for LLM
error: Mapped[str | None] = mapped_column(Text, nullable=True)
time_sent: Mapped[datetime.datetime] = mapped_column(

View File

@ -0,0 +1,33 @@
import base64
from enum import Enum
from typing import TypedDict
from uuid import UUID
from pydantic import BaseModel
class ChatFileType(str, Enum):
IMAGE = "image"
class FileDescriptor(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""
id: str
type: ChatFileType
class InMemoryChatFile(BaseModel):
file_id: UUID
content: bytes
file_type: ChatFileType = ChatFileType.IMAGE
def to_base64(self) -> str:
return base64.b64encode(self.content).decode()
def to_file_descriptor(self) -> FileDescriptor:
return {
"id": str(self.file_id),
"type": self.file_type,
}

View File

@ -0,0 +1,40 @@
from typing import cast
from uuid import UUID
from sqlalchemy.orm import Session
from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def build_chat_file_name(file_id: UUID | str) -> str:
return f"chat__{file_id}"
def load_chat_file(file_id: UUID, db_session: Session) -> InMemoryChatFile:
file_io = get_default_file_store(db_session).read_file(
build_chat_file_name(file_id), mode="b"
)
return InMemoryChatFile(file_id=file_id, content=file_io.read())
def load_all_chat_files(
chat_messages: list[ChatMessage], new_file_ids: list[UUID], db_session: Session
) -> list[InMemoryChatFile]:
file_ids_for_history = []
for chat_message in chat_messages:
if chat_message.files:
file_ids_for_history.extend([file["id"] for file in chat_message.files])
files = cast(
list[InMemoryChatFile],
run_functions_tuples_in_parallel(
[
(load_chat_file, (file_id, db_session))
for file_id in new_file_ids + file_ids_for_history
]
),
)
return files

View File

@ -9,7 +9,7 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
@ -58,7 +58,9 @@ class Answer:
doc_relevance_list: list[bool] | None = None,
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
timeout: int = QA_TIMEOUT,
# newly passed in files to include as part of this question
latest_query_files: list[InMemoryChatFile] | None = None,
files: list[InMemoryChatFile] | None = None,
) -> None:
if single_message_history and message_history:
raise ValueError(
@ -67,6 +69,10 @@ class Answer:
self.question = question
self.docs = docs
self.latest_query_files = latest_query_files or []
self.file_id_to_file = {file.file_id: file for file in (files or [])}
self.doc_relevance_list = doc_relevance_list
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
@ -112,11 +118,15 @@ class Answer:
llm_config=self.llm.config,
prompt_config=self.prompt_config,
context_docs=self.pruned_docs,
latest_query_files=self.latest_query_files,
all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
llm_tokenizer_encode_func=self.llm_tokenizer.encode,
history_message=self.single_message_history or "",
)
elif self.answer_style_config.quotes_config:
# NOTE: quotes prompt doesn't currently support files
# this is okay for now, since the search UI (which uses this)
# doesn't support image upload
self._final_prompt = build_quotes_prompt(
question=self.question,
context_docs=self.pruned_docs,

View File

@ -9,6 +9,7 @@ from pydantic import root_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
if TYPE_CHECKING:
@ -25,13 +26,24 @@ class PreviousMessage(BaseModel):
message: str
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
@classmethod
def from_chat_message(cls, chat_message: "ChatMessage") -> "PreviousMessage":
def from_chat_message(
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
) -> "PreviousMessage":
message_file_ids = (
[file["id"] for file in chat_message.files] if chat_message.files else []
)
return cls(
message=chat_message.message,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
files=[
file
for file in available_files
if str(file.file_id) in message_file_ids
],
)

View File

@ -11,10 +11,12 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_default_prompt
from danswer.db.models import Persona
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_max_input_tokens
@ -59,7 +61,7 @@ def find_last_index(lst: list[int], max_prompt_tokens: int) -> int:
return last_ind
def drop_messages_history_overflow(
def _drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
@ -171,7 +173,7 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
@lru_cache()
def build_system_message(
def _build_system_message(
prompt_config: PromptConfig,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
@ -201,10 +203,11 @@ def build_system_message(
return system_msg, token_count
def build_user_message(
def _build_user_message(
question: str,
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
files: list[InMemoryChatFile],
all_doc_useful: bool,
history_message: str,
) -> tuple[HumanMessage, int]:
@ -222,7 +225,11 @@ def build_user_message(
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files)
if files
else user_prompt
)
return user_msg, token_count
context_docs_str = build_complete_context_str(context_docs)
@ -240,7 +247,9 @@ def build_user_message(
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer_encode_func(user_prompt))
user_msg = HumanMessage(content=user_prompt)
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg, token_count
@ -251,13 +260,14 @@ def build_citations_prompt(
prompt_config: PromptConfig,
llm_config: LLMConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
latest_query_files: list[InMemoryChatFile],
all_doc_useful: bool,
history_message: str,
llm_tokenizer_encode_func: Callable,
) -> list[BaseMessage]:
context_exists = len(context_docs) > 0
system_message_or_none, system_tokens = build_system_message(
system_message_or_none, system_tokens = _build_system_message(
prompt_config=prompt_config,
context_exists=context_exists,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
@ -269,15 +279,16 @@ def build_citations_prompt(
# Be sure the context_docs passed to build_chat_user_message
# Is the same as passed in later for extracting citations
user_message, user_tokens = build_user_message(
user_message, user_tokens = _build_user_message(
question=question,
prompt_config=prompt_config,
context_docs=context_docs,
files=latest_query_files,
all_doc_useful=all_doc_useful,
history_message=history_message,
)
final_prompt_msgs = drop_messages_history_overflow(
final_prompt_msgs = _drop_messages_history_overflow(
system_msg=system_message_or_none,
system_token_count=system_tokens,
history_msgs=history_basemessages,

View File

@ -18,8 +18,10 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"gpt-4",
"gpt-4-turbo",
"gpt-4-turbo-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",

View File

@ -2,6 +2,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union
@ -24,6 +25,7 @@ from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.db.models import ChatMessage
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
@ -85,12 +87,17 @@ def tokenizer_trim_chunks(
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now
files = [] if isinstance(msg, ChatMessage) else msg.files
content = build_content_with_imgs(msg.message, files)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=msg.message)
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
return HumanMessage(content=msg.message)
return HumanMessage(content=content)
raise ValueError(f"New message type {msg.message_type} not handled")
@ -107,6 +114,33 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def build_content_with_imgs(
message: str, files: list[InMemoryChatFile]
) -> str | list[str | dict]: # matching Langchain's BaseMessage content type
if not files:
return message
return cast(
list[str | dict],
[
{
"type": "text",
"text": message,
},
]
+ [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{file.to_base64()}",
},
}
for file in files
if file.file_type == "image"
],
)
def dict_based_prompt_to_langchain_prompt(
messages: list[dict[str, str]]
) -> list[BaseMessage]:

View File

@ -215,7 +215,6 @@ def stream_answer_objects(
llm=get_llm_for_persona(persona=chat_session.persona),
doc_relevance_list=search_pipeline.section_relevance_list,
single_message_history=history_str,
timeout=timeout,
)
yield from answer.processed_streamed_output

View File

@ -59,7 +59,6 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.file_store import get_default_file_store
from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import create_index_attempt
@ -67,6 +66,7 @@ from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.models import User
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorBase

View File

@ -24,13 +24,13 @@ from danswer.db.engine import get_session
from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.feedback import update_document_boost
from danswer.db.feedback import update_document_hidden
from danswer.db.file_store import get_default_file_store
from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier

View File

@ -1,6 +1,10 @@
import uuid
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
@ -28,6 +32,8 @@ from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.utils import build_chat_file_name
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
@ -404,3 +410,50 @@ def seed_chat(
return ChatSeedResponse(
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
)
"""File upload"""
@router.post("/file")
def upload_files_for_chat(
files: list[UploadFile],
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> dict[str, list[uuid.UUID]]:
for file in files:
if file.content_type not in ("image/jpeg", "image/png", "image/webp"):
raise HTTPException(
status_code=400,
detail="Only .jpg, .jpeg, .png, and .webp files are currently supported",
)
if file.size and file.size > 20 * 1024 * 1024:
raise HTTPException(
status_code=400,
detail="File size must be less than 20MB",
)
file_store = get_default_file_store(db_session)
file_ids = []
for file in files:
file_id = uuid.uuid4()
file_name = build_chat_file_name(file_id)
file_store.save_file(file_name=file_name, content=file.file)
file_ids.append(file_id)
return {"file_ids": file_ids}
@router.get("/file/{file_id}")
def fetch_chat_file(
file_id: str,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> Response:
file_store = get_default_file_store(db_session)
file_io = file_store.read_file(build_chat_file_name(file_id), mode="b")
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")

View File

@ -1,5 +1,6 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import root_validator
@ -9,6 +10,7 @@ from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.db.enums import ChatSessionSharedStatus
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import BaseFilters
@ -81,6 +83,8 @@ class CreateChatMessageRequest(ChunkContext):
parent_message_id: int | None
# New message contents
message: str
# file's that we should attach to this message
file_ids: list[UUID]
# If no prompt provided, uses the largest prompt of the chat session
# but really this should be explicitly specified, only in the simplified APIs is this inferred
# Use prompt_id 0 to use the system default prompt which is Answer-Question
@ -171,6 +175,7 @@ class ChatMessageDetail(BaseModel):
time_sent: datetime
# Dict mapping citation number to db_doc_id
citations: dict[int, int] | None
files: list[FileDescriptor]
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore

37
web/package-lock.json generated
View File

@ -12,6 +12,7 @@
"@dnd-kit/modifiers": "^7.0.0",
"@dnd-kit/sortable": "^8.0.0",
"@phosphor-icons/react": "^2.0.8",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.0.7",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
@ -1221,6 +1222,42 @@
}
}
},
"node_modules/@radix-ui/react-dialog": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-dialog/-/react-dialog-1.0.5.tgz",
"integrity": "sha512-GjWJX/AUpB703eEBanuBnIWdIXg6NvJFCXcNlSZk4xdszCdhrJgBoUd1cGk67vFO+WdA2pfI/plOpqz/5GUP6Q==",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/primitive": "1.0.1",
"@radix-ui/react-compose-refs": "1.0.1",
"@radix-ui/react-context": "1.0.1",
"@radix-ui/react-dismissable-layer": "1.0.5",
"@radix-ui/react-focus-guards": "1.0.1",
"@radix-ui/react-focus-scope": "1.0.4",
"@radix-ui/react-id": "1.0.1",
"@radix-ui/react-portal": "1.0.4",
"@radix-ui/react-presence": "1.0.1",
"@radix-ui/react-primitive": "1.0.3",
"@radix-ui/react-slot": "1.0.2",
"@radix-ui/react-use-controllable-state": "1.0.1",
"aria-hidden": "^1.1.1",
"react-remove-scroll": "2.5.5"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0",
"react-dom": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-dismissable-layer": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz",

View File

@ -13,6 +13,7 @@
"@dnd-kit/modifiers": "^7.0.0",
"@dnd-kit/sortable": "^8.0.0",
"@phosphor-icons/react": "^2.0.8",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.0.7",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",

View File

@ -27,3 +27,11 @@ export interface FullLLMProvider extends LLMProvider {
is_default_provider: boolean | null;
model_names: string[];
}
export interface LLMProviderDescriptor {
name: string;
model_names: string[];
default_model_name: string;
fast_default_model_name: string | null;
is_default_provider: boolean | null;
}

View File

@ -7,6 +7,7 @@ import {
ChatSession,
ChatSessionSharedStatus,
DocumentsResponse,
FileDescriptor,
Message,
RetrievalType,
StreamingError,
@ -30,6 +31,7 @@ import {
personaIncludesRetrieval,
processRawChatHistory,
sendMessage,
uploadFilesForChat,
} from "./lib";
import { useContext, useEffect, useRef, useState } from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
@ -56,16 +58,21 @@ import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces";
import { buildFilters } from "@/lib/search/utils";
import { Tabs } from "./sessionSidebar/constants";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { LLMProviderDescriptor } from "../admin/models/llm/interfaces";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import { InputBarPreviewImage } from "./images/InputBarPreviewImage";
const MAX_INPUT_HEIGHT = 200;
export function ChatLayout({
export function ChatPage({
user,
chatSessions,
availableSources,
availableDocumentSets,
availablePersonas,
availableTags,
llmProviders,
defaultSelectedPersonaId,
documentSidebarInitialWidth,
defaultSidebarTab,
@ -76,6 +83,7 @@ export function ChatLayout({
availableDocumentSets: DocumentSet[];
availablePersonas: Persona[];
availableTags: Tag[];
llmProviders: LLMProviderDescriptor[];
defaultSelectedPersonaId?: number; // what persona to default to
documentSidebarInitialWidth?: number;
defaultSidebarTab?: Tabs;
@ -123,6 +131,9 @@ export function ChatLayout({
filterManager.setSelectedSources([]);
filterManager.setSelectedTags([]);
filterManager.setTimeRange(null);
// remove uploaded files
setCurrentMessageFileIds([]);
if (isStreaming) {
setIsCancelled(true);
}
@ -209,6 +220,11 @@ export function ChatLayout({
const [messageHistory, setMessageHistory] = useState<Message[]>([]);
const [isStreaming, setIsStreaming] = useState(false);
// uploaded files
const [currentMessageFileIds, setCurrentMessageFileIds] = useState<string[]>(
[]
);
// for document display
// NOTE: -1 is a special designation that means the latest AI message
const [selectedMessageForDocDisplay, setSelectedMessageForDocDisplay] =
@ -404,15 +420,21 @@ export function ChatLayout({
messageToResendIndex !== null
? messageHistory.slice(0, messageToResendIndex)
: messageHistory;
const currFiles = currentMessageFileIds.map((id) => ({
id,
type: "image",
})) as FileDescriptor[];
setMessageHistory([
...currMessageHistory,
{
messageId: 0,
message: currMessage,
type: "user",
files: currFiles,
},
]);
setMessage("");
setCurrentMessageFileIds([]);
setIsStreaming(true);
let answer = "";
@ -429,6 +451,7 @@ export function ChatLayout({
getLastSuccessfulMessageId(currMessageHistory);
for await (const packetBunch of sendMessage({
message: currMessage,
fileIds: currentMessageFileIds,
parentMessageId: lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: livePersona?.prompts[0]?.id || 0,
@ -479,6 +502,7 @@ export function ChatLayout({
messageId: finalMessage?.parent_message || null,
message: currMessage,
type: "user",
files: currFiles,
},
{
messageId: finalMessage?.message_id || null,
@ -488,6 +512,7 @@ export function ChatLayout({
query: finalMessage?.rephrased_query || query,
documents: finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || [],
},
]);
if (isCancelledRef.current) {
@ -503,11 +528,13 @@ export function ChatLayout({
messageId: null,
message: currMessage,
type: "user",
files: currFiles,
},
{
messageId: null,
message: errorMsg,
type: "error",
files: [],
},
]);
}
@ -570,7 +597,10 @@ export function ChatLayout({
};
const onPersonaChange = (persona: Persona | null) => {
if (persona) {
if (persona && persona.id !== livePersona.id) {
// remove uploaded files
setCurrentMessageFileIds([]);
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(buildChatUrl(searchParams, null, persona.id));
@ -638,220 +668,256 @@ export function ChatLayout({
)}
{documentSidebarInitialWidth !== undefined ? (
<>
<div
className={`w-full sm:relative h-screen ${
retrievalDisabled ? "pb-[111px]" : "pb-[140px]"
}`}
>
<div
className={`w-full h-full ${HEADER_PADDING} flex flex-col overflow-y-auto overflow-x-hidden relative`}
ref={scrollableDivRef}
>
{livePersona && (
<div className="sticky top-0 left-80 z-10 w-full bg-background/90 flex">
<div className="ml-2 p-1 rounded mt-2 w-fit">
<ChatPersonaSelector
personas={availablePersonas}
selectedPersonaId={livePersona.id}
onPersonaChange={onPersonaChange}
/>
</div>
<Dropzone
onDrop={(acceptedFiles) => {
uploadFilesForChat(acceptedFiles).then(([fileIds, error]) => {
if (error) {
setPopup({
type: "error",
message: error,
});
} else {
const newFileIds = [...currentMessageFileIds, ...fileIds];
setCurrentMessageFileIds(newFileIds);
}
});
}}
noClick
disabled={
!checkLLMSupportsImageInput(
...getFinalLLM(llmProviders, livePersona)
)
}
onDragLeave={() => console.log("buh")}
onDragEnter={() => console.log("floppa")}
>
{({ getRootProps }) => (
<>
<div
className={`w-full sm:relative h-screen ${
retrievalDisabled ? "pb-[111px]" : "pb-[140px]"
}`}
{...getRootProps()}
>
{/* <input {...getInputProps()} /> */}
<div
className={`w-full h-full ${HEADER_PADDING} flex flex-col overflow-y-auto overflow-x-hidden relative`}
ref={scrollableDivRef}
>
{livePersona && (
<div className="sticky top-0 left-80 z-10 w-full bg-background/90 flex">
<div className="ml-2 p-1 rounded mt-2 w-fit">
<ChatPersonaSelector
personas={availablePersonas}
selectedPersonaId={livePersona.id}
onPersonaChange={onPersonaChange}
/>
</div>
{chatSessionId !== null && (
<div
onClick={() => setSharingModalVisible(true)}
className="ml-auto mr-6 my-auto border-border border p-2 rounded cursor-pointer hover:bg-hover-light"
>
<FiShare2 />
{chatSessionId !== null && (
<div
onClick={() => setSharingModalVisible(true)}
className="ml-auto mr-6 my-auto border-border border p-2 rounded cursor-pointer hover:bg-hover-light"
>
<FiShare2 />
</div>
)}
</div>
)}
</div>
)}
{messageHistory.length === 0 &&
!isFetchingChatMessages &&
!isStreaming && (
<ChatIntro
availableSources={finalAvailableSources}
availablePersonas={availablePersonas}
selectedPersona={selectedPersona}
handlePersonaSelect={(persona) => {
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(
buildChatUrl(searchParams, null, persona.id)
);
}}
/>
)}
{messageHistory.length === 0 &&
!isFetchingChatMessages &&
!isStreaming && (
<ChatIntro
availableSources={finalAvailableSources}
availablePersonas={availablePersonas}
selectedPersona={selectedPersona}
handlePersonaSelect={(persona) => {
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(
buildChatUrl(searchParams, null, persona.id)
);
}}
/>
)}
<div
className={
"mt-4 pt-12 sm:pt-0 mx-8" +
(hasPerformedInitialScroll ? "" : " invisible")
}
>
{messageHistory.map((message, i) => {
if (message.type === "user") {
return (
<div key={i}>
<HumanMessage content={message.message} />
</div>
);
} else if (message.type === "assistant") {
const isShowingRetrieved =
(selectedMessageForDocDisplay !== null &&
selectedMessageForDocDisplay ===
message.messageId) ||
(selectedMessageForDocDisplay === -1 &&
i === messageHistory.length - 1);
const previousMessage =
i !== 0 ? messageHistory[i - 1] : null;
return (
<div key={i}>
<AIMessage
messageId={message.messageId}
content={message.message}
query={messageHistory[i]?.query || undefined}
personaName={livePersona.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
isComplete={
i !== messageHistory.length - 1 || !isStreaming
}
hasDocs={
(message.documents &&
message.documents.length > 0) === true
}
handleFeedback={
i === messageHistory.length - 1 && isStreaming
? undefined
: (feedbackType) =>
setCurrentFeedback([
feedbackType,
message.messageId as number,
])
}
handleSearchQueryEdit={
i === messageHistory.length - 1 && !isStreaming
? (newQuery) => {
if (!previousMessage) {
setPopup({
type: "error",
message:
"Cannot edit query of first message - please refresh the page and try again.",
});
return;
}
if (previousMessage.messageId === null) {
setPopup({
type: "error",
message:
"Cannot edit query of a pending message - please wait a few seconds and try again.",
});
return;
<div
className={
"mt-4 pt-12 sm:pt-0 mx-8" +
(hasPerformedInitialScroll ? "" : " invisible")
}
>
{messageHistory.map((message, i) => {
if (message.type === "user") {
return (
<div key={i}>
<HumanMessage
content={message.message}
files={message.files}
/>
</div>
);
} else if (message.type === "assistant") {
const isShowingRetrieved =
(selectedMessageForDocDisplay !== null &&
selectedMessageForDocDisplay ===
message.messageId) ||
(selectedMessageForDocDisplay === -1 &&
i === messageHistory.length - 1);
const previousMessage =
i !== 0 ? messageHistory[i - 1] : null;
return (
<div key={i}>
<AIMessage
messageId={message.messageId}
content={message.message}
query={messageHistory[i]?.query || undefined}
personaName={livePersona.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
isComplete={
i !== messageHistory.length - 1 ||
!isStreaming
}
hasDocs={
(message.documents &&
message.documents.length > 0) === true
}
handleFeedback={
i === messageHistory.length - 1 &&
isStreaming
? undefined
: (feedbackType) =>
setCurrentFeedback([
feedbackType,
message.messageId as number,
])
}
handleSearchQueryEdit={
i === messageHistory.length - 1 &&
!isStreaming
? (newQuery) => {
if (!previousMessage) {
setPopup({
type: "error",
message:
"Cannot edit query of first message - please refresh the page and try again.",
});
return;
}
if (
previousMessage.messageId === null
) {
setPopup({
type: "error",
message:
"Cannot edit query of a pending message - please wait a few seconds and try again.",
});
return;
}
onSubmit({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
});
}
: undefined
}
isCurrentlyShowingRetrieved={
isShowingRetrieved
}
handleShowRetrieved={(messageNumber) => {
if (isShowingRetrieved) {
setSelectedMessageForDocDisplay(null);
} else {
if (messageNumber !== null) {
setSelectedMessageForDocDisplay(
messageNumber
);
} else {
setSelectedMessageForDocDisplay(-1);
}
}
}}
handleForceSearch={() => {
if (
previousMessage &&
previousMessage.messageId
) {
onSubmit({
messageIdToResend:
previousMessage.messageId,
queryOverride: newQuery,
forceSearch: true,
});
} else {
setPopup({
type: "error",
message:
"Failed to force search - please refresh the page and try again.",
});
}
: undefined
}
isCurrentlyShowingRetrieved={isShowingRetrieved}
handleShowRetrieved={(messageNumber) => {
if (isShowingRetrieved) {
setSelectedMessageForDocDisplay(null);
} else {
if (messageNumber !== null) {
setSelectedMessageForDocDisplay(
messageNumber
);
} else {
setSelectedMessageForDocDisplay(-1);
}
}
}}
handleForceSearch={() => {
if (
previousMessage &&
previousMessage.messageId
) {
onSubmit({
messageIdToResend:
previousMessage.messageId,
forceSearch: true,
});
} else {
setPopup({
type: "error",
message:
"Failed to force search - please refresh the page and try again.",
});
}
}}
retrievalDisabled={retrievalDisabled}
/>
</div>
);
} else {
return (
<div key={i}>
<AIMessage
messageId={message.messageId}
personaName={livePersona.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
</p>
}
/>
</div>
);
}
})}
{isStreaming &&
messageHistory.length &&
messageHistory[messageHistory.length - 1].type ===
"user" && (
<div key={messageHistory.length}>
<AIMessage
messageId={null}
personaName={livePersona.name}
content={
<div className="text-sm my-auto">
<ThreeDots
height="30"
width="50"
color="#3b82f6"
ariaLabel="grid-loading"
radius="12.5"
wrapperStyle={{}}
wrapperClass=""
visible={true}
}}
retrievalDisabled={retrievalDisabled}
/>
</div>
}
/>
</div>
)}
);
} else {
return (
<div key={i}>
<AIMessage
messageId={message.messageId}
personaName={livePersona.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
</p>
}
/>
</div>
);
}
})}
{/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
<div className={`min-h-[30px] w-full`}></div>
{isStreaming &&
messageHistory.length &&
messageHistory[messageHistory.length - 1].type ===
"user" && (
<div key={messageHistory.length}>
<AIMessage
messageId={null}
personaName={livePersona.name}
content={
<div className="text-sm my-auto">
<ThreeDots
height="30"
width="50"
color="#3b82f6"
ariaLabel="grid-loading"
radius="12.5"
wrapperStyle={{}}
wrapperClass=""
visible={true}
/>
</div>
}
/>
</div>
)}
{livePersona &&
livePersona.starter_messages &&
livePersona.starter_messages.length > 0 &&
selectedPersona &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
className={`
{/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/}
<div className={`min-h-[30px] w-full`}></div>
{livePersona &&
livePersona.starter_messages &&
livePersona.starter_messages.length > 0 &&
selectedPersona &&
messageHistory.length === 0 &&
!isFetchingChatMessages && (
<div
className={`
mx-auto
px-4
w-searchbar-xs
@ -864,156 +930,193 @@ export function ChatLayout({
mt-4
md:grid-cols-2
mb-6`}
>
{livePersona.starter_messages.map(
(starterMessage, i) => (
<div key={i} className="w-full">
<StarterMessage
starterMessage={starterMessage}
onClick={() =>
onSubmit({
messageOverride: starterMessage.message,
})
}
/>
</div>
)
>
{livePersona.starter_messages.map(
(starterMessage, i) => (
<div key={i} className="w-full">
<StarterMessage
starterMessage={starterMessage}
onClick={() =>
onSubmit({
messageOverride:
starterMessage.message,
})
}
/>
</div>
)
)}
</div>
)}
</div>
)}
<div ref={endDivRef} />
</div>
</div>
<div className="absolute bottom-0 z-10 w-full bg-background border-t border-border">
<div className="w-full pb-4 pt-2">
{!retrievalDisabled && (
<div className="flex">
<div className="w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto px-4 pt-1 flex">
{selectedDocuments.length > 0 ? (
<SelectedDocuments
selectedDocuments={selectedDocuments}
/>
) : (
<ChatFilters
{...filterManager}
existingSources={finalAvailableSources}
availableDocumentSets={finalAvailableDocumentSets}
availableTags={availableTags}
/>
)}
</div>
<div ref={endDivRef} />
</div>
)}
</div>
<div className="flex justify-center py-2 max-w-screen-lg mx-auto mb-2">
<div className="w-full shrink relative px-4 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto">
<textarea
ref={textareaRef}
autoFocus
className={`
opacity-100
w-full
shrink
border
border-border
rounded-lg
outline-none
placeholder-gray-400
pl-4
pr-12
py-4
overflow-hidden
h-14
${
(textareaRef?.current?.scrollHeight || 0) >
MAX_INPUT_HEIGHT
? "overflow-y-auto"
: ""
}
whitespace-normal
break-word
overscroll-contain
resize-none
`}
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder="Ask me anything..."
value={message}
onChange={(e) => setMessage(e.target.value)}
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!event.shiftKey &&
message &&
!isStreaming
) {
onSubmit();
event.preventDefault();
}
}}
suppressContentEditableWarning={true}
/>
<div className="absolute bottom-4 right-10">
<div
className={"cursor-pointer"}
onClick={() => {
if (!isStreaming) {
if (message) {
onSubmit();
}
} else {
setIsCancelled(true);
}
}}
>
{isStreaming ? (
<FiStopCircle
size={18}
className={
"text-emphasis w-9 h-9 p-2 rounded-lg hover:bg-hover"
}
<div className="absolute bottom-0 z-10 w-full bg-background border-t border-border">
<div className="w-full pb-4 pt-2">
{!retrievalDisabled && (
<div className="flex">
<div className="w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto px-4 pt-1 flex">
{selectedDocuments.length > 0 ? (
<SelectedDocuments
selectedDocuments={selectedDocuments}
/>
) : (
<ChatFilters
{...filterManager}
existingSources={finalAvailableSources}
availableDocumentSets={
finalAvailableDocumentSets
}
availableTags={availableTags}
/>
)}
</div>
</div>
)}
<div className="flex justify-center py-2 max-w-screen-lg mx-auto mb-2">
<div className="w-full shrink relative px-4 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar mx-auto">
<div
className={`
opacity-100
w-full
h-fit
flex
flex-col
border
border-border
rounded-lg
[&:has(textarea:focus)]::ring-1
[&:has(textarea:focus)]::ring-black
`}
>
{currentMessageFileIds.length > 0 && (
<div className="flex flex-wrap gap-y-2 px-1">
{currentMessageFileIds.map((fileId) => (
<div key={fileId} className="py-1">
<InputBarPreviewImage
fileId={fileId}
onDelete={() => {
setCurrentMessageFileIds(
currentMessageFileIds.filter(
(id) => id !== fileId
)
);
}}
/>
</div>
))}
</div>
)}
<textarea
ref={textareaRef}
className={`
m-0
w-full
shrink
resize-none
border-0
bg-transparent
${
(textareaRef?.current?.scrollHeight || 0) >
MAX_INPUT_HEIGHT
? "overflow-y-auto"
: ""
}
whitespace-normal
break-word
overscroll-contain
outline-none
placeholder-gray-400
overflow-hidden
resize-none
pl-4
pr-12
py-4
h-14`}
autoFocus
style={{ scrollbarWidth: "thin" }}
role="textarea"
aria-multiline
placeholder="Ask me anything..."
value={message}
onChange={(e) => setMessage(e.target.value)}
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!event.shiftKey &&
message &&
!isStreaming
) {
onSubmit();
event.preventDefault();
}
}}
suppressContentEditableWarning={true}
/>
) : (
<FiSend
size={18}
className={
"text-emphasis w-9 h-9 p-2 rounded-lg " +
(message ? "bg-blue-200" : "")
}
/>
)}
</div>
<div className="absolute bottom-2.5 right-10">
<div
className={"cursor-pointer"}
onClick={() => {
if (!isStreaming) {
if (message) {
onSubmit();
}
} else {
setIsCancelled(true);
}
}}
>
{isStreaming ? (
<FiStopCircle
size={18}
className={
"text-emphasis w-9 h-9 p-2 rounded-lg hover:bg-hover"
}
/>
) : (
<FiSend
size={18}
className={
"text-emphasis w-9 h-9 p-2 rounded-lg " +
(message ? "bg-blue-200" : "")
}
/>
)}
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
{!retrievalDisabled ? (
<ResizableSection
intialWidth={documentSidebarInitialWidth}
minWidth={400}
maxWidth={maxDocumentSidebarWidth || undefined}
>
<DocumentSidebar
selectedMessage={aiMessage}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
isLoading={isFetchingChatMessages}
/>
</ResizableSection>
) : // Another option is to use a div with the width set to the initial width, so that the
// chat section appears in the same place as before
// <div style={documentSidebarInitialWidth ? {width: documentSidebarInitialWidth} : {}}></div>
null}
</>
{!retrievalDisabled ? (
<ResizableSection
intialWidth={documentSidebarInitialWidth as number}
minWidth={400}
maxWidth={maxDocumentSidebarWidth || undefined}
>
<DocumentSidebar
selectedMessage={aiMessage}
selectedDocuments={selectedDocuments}
toggleDocumentSelection={toggleDocumentSelection}
clearSelectedDocuments={clearSelectedDocuments}
selectedDocumentTokens={selectedDocumentTokens}
maxTokens={maxTokens}
isLoading={isFetchingChatMessages}
/>
</ResizableSection>
) : // Another option is to use a div with the width set to the initial width, so that the
// chat section appears in the same place as before
// <div style={documentSidebarInitialWidth ? {width: documentSidebarInitialWidth} : {}}></div>
null}
</>
)}
</Dropzone>
) : (
<div className="mx-auto h-full flex flex-col">
<div className="my-auto">

View File

@ -0,0 +1,50 @@
"use client";
import { useEffect } from "react";
import { buildImgUrl } from "./utils";
import * as Dialog from "@radix-ui/react-dialog";
export function FullImageModal({
fileId,
open,
onOpenChange,
}: {
fileId: string;
open: boolean;
onOpenChange: (open: boolean) => void;
}) {
// pre-fetch image
useEffect(() => {
const img = new Image();
img.src = buildImgUrl(fileId);
}, [fileId]);
return (
<Dialog.Root open={open} onOpenChange={onOpenChange}>
<Dialog.Portal>
<Dialog.Overlay className="fixed inset-0 bg-black bg-opacity-80 z-50" />
<Dialog.Content
className={`fixed
inset-0
flex
items-center
justify-center
p-4
z-[100]
max-w-screen-lg
h-fit
top-1/2
left-1/2
-translate-y-2/4
-translate-x-2/4`}
>
<img
src={buildImgUrl(fileId)}
alt="Uploaded image"
className="max-w-full max-h-full"
/>
</Dialog.Content>
</Dialog.Portal>
</Dialog.Root>
);
}

View File

@ -0,0 +1,33 @@
"use client";
import { useState } from "react";
import { FullImageModal } from "./FullImageModal";
import { buildImgUrl } from "./utils";
export function InMessageImage({ fileId }: { fileId: string }) {
const [fullImageShowing, setFullImageShowing] = useState(false);
return (
<>
<FullImageModal
fileId={fileId}
open={fullImageShowing}
onOpenChange={(open) => setFullImageShowing(open)}
/>
<img
className={`
max-w-lg
rounded-lg
bg-transparent
cursor-pointer
transition-opacity
duration-300
opacity-100`}
onClick={() => setFullImageShowing(true)}
src={buildImgUrl(fileId)}
loading="lazy"
/>
</>
);
}

View File

@ -0,0 +1,46 @@
"use client";
import { useState } from "react";
import { FiX } from "react-icons/fi";
import { buildImgUrl } from "./utils";
import { FullImageModal } from "./FullImageModal";
export function InputBarPreviewImage({
fileId,
onDelete,
}: {
fileId: string;
onDelete: () => void;
}) {
const [isHovered, setIsHovered] = useState(false);
const [fullImageShowing, setFullImageShowing] = useState(false);
return (
<>
<FullImageModal
fileId={fileId}
open={fullImageShowing}
onOpenChange={(open) => setFullImageShowing(open)}
/>
<div
className="p-1 relative"
onMouseEnter={() => setIsHovered(true)}
onMouseLeave={() => setIsHovered(false)}
>
{isHovered && (
<button
onClick={onDelete}
className="absolute top-0 right-0 cursor-pointer border-none bg-hover p-1 rounded-full"
>
<FiX />
</button>
)}
<img
onClick={() => setFullImageShowing(true)}
className="h-16 w-16 object-cover rounded-lg bg-background cursor-pointer"
src={buildImgUrl(fileId)}
/>
</div>
</>
);
}

View File

@ -0,0 +1,3 @@
export function buildImgUrl(fileId: string) {
return `/api/chat/file/${fileId}`;
}

View File

@ -20,6 +20,11 @@ export interface RetrievalDetails {
type CitationMap = { [key: string]: number };
export interface FileDescriptor {
id: string;
type: "image";
}
export interface ChatSession {
id: number;
name: string;
@ -36,6 +41,7 @@ export interface Message {
query?: string | null;
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files: FileDescriptor[];
}
export interface BackendChatSession {
@ -58,6 +64,7 @@ export interface BackendMessage {
message_type: "user" | "assistant" | "system";
time_sent: string;
citations: CitationMap;
files: FileDescriptor[];
}
export interface DocumentsResponse {

View File

@ -47,6 +47,7 @@ export async function createChatSession(
export async function* sendMessage({
message,
fileIds,
parentMessageId,
chatSessionId,
promptId,
@ -60,6 +61,7 @@ export async function* sendMessage({
useExistingUserMessage,
}: {
message: string;
fileIds: string[];
parentMessageId: number | null;
chatSessionId: number;
promptId: number | null | undefined;
@ -89,6 +91,7 @@ export async function* sendMessage({
message: message,
prompt_id: promptId,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_ids: fileIds,
retrieval_options: !documentsAreSelected
? {
run_search:
@ -364,6 +367,7 @@ export function processRawChatHistory(rawMessages: BackendMessage[]) {
messageId: messageInfo.message_id,
message: messageInfo.message,
type: messageInfo.message_type as "user" | "assistant",
files: messageInfo.files,
// only include these fields if this is an assistant message so that
// this is identical to what is computed at streaming time
...(messageInfo.message_type === "assistant"
@ -419,3 +423,23 @@ export function buildChatUrl(
return "/chat";
}
export async function uploadFilesForChat(
files: File[]
): Promise<[string[], string | null]> {
const formData = new FormData();
files.forEach((file) => {
formData.append("files", file);
});
const response = await fetch("/api/chat/file", {
method: "POST",
body: formData,
});
if (!response.ok) {
return [[], `Failed to upload files - ${(await response.json()).detail}`];
}
const responseJson = await response.json();
return [responseJson.file_ids as string[], null];
}

View File

@ -16,6 +16,8 @@ import { ThreeDots } from "react-loader-spinner";
import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton";
import { FileDescriptor } from "../interfaces";
import { InMessageImage } from "../images/InMessageImage";
export const Hoverable: React.FC<{
children: JSX.Element;
@ -219,8 +221,10 @@ export const AIMessage = ({
export const HumanMessage = ({
content,
files,
}: {
content: string | JSX.Element;
files?: FileDescriptor[];
}) => {
return (
<div className="py-5 px-5 flex -mr-6 w-full">
@ -237,6 +241,16 @@ export const HumanMessage = ({
</div>
<div className="mx-auto mt-1 ml-8 w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar-default flex flex-wrap">
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words">
{files && files.length > 0 && (
<div className="mt-2 mb-4">
<div className="flex flex-wrap gap-2">
{files.map((file) => {
return <InMessageImage key={file.id} fileId={file.id} />;
})}
</div>
</div>
)}
{typeof content === "string" ? (
<ReactMarkdown
className="prose max-w-full"

View File

@ -24,11 +24,13 @@ import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { cookies } from "next/headers";
import { DOCUMENT_SIDEBAR_WIDTH_COOKIE_NAME } from "@/components/resizable/contants";
import { personaComparator } from "../admin/assistants/lib";
import { ChatLayout } from "./ChatPage";
import { ChatPage } from "./ChatPage";
import { FullEmbeddingModelResponse } from "../admin/models/embedding/embeddingModels";
import { NoCompleteSourcesModal } from "@/components/initialSetup/search/NoCompleteSourceModal";
import { Settings } from "../admin/settings/interfaces";
import { SIDEBAR_TAB_COOKIE, Tabs } from "./sessionSidebar/constants";
import { fetchLLMProvidersSS } from "@/lib/llm/fetchLLMs";
import { LLMProviderDescriptor } from "../admin/models/llm/interfaces";
export default async function Page({
searchParams,
@ -45,6 +47,7 @@ export default async function Page({
fetchSS("/persona?include_default=true"),
fetchSS("/chat/get-user-chat-sessions"),
fetchSS("/query/valid-tags"),
fetchLLMProvidersSS(),
];
// catch cases where the backend is completely unreachable here
@ -56,8 +59,9 @@ export default async function Page({
| AuthTypeMetadata
| FullEmbeddingModelResponse
| Settings
| LLMProviderDescriptor[]
| null
)[] = [null, null, null, null, null, null, null, null, null];
)[] = [null, null, null, null, null, null, null, null, null, null];
try {
results = await Promise.all(tasks);
} catch (e) {
@ -70,6 +74,7 @@ export default async function Page({
const personasResponse = results[4] as Response | null;
const chatSessionsResponse = results[5] as Response | null;
const tagsResponse = results[6] as Response | null;
const llmProviders = (results[7] || []) as LLMProviderDescriptor[];
const authDisabled = authTypeMetadata?.authType === "disabled";
if (!authDisabled && !user) {
@ -177,13 +182,14 @@ export default async function Page({
<NoCompleteSourcesModal ccPairs={ccPairs} />
)}
<ChatLayout
<ChatPage
user={user}
chatSessions={chatSessions}
availableSources={availableSources}
availableDocumentSets={documentSets}
availablePersonas={personas}
availableTags={tags}
llmProviders={llmProviders}
defaultSelectedPersonaId={defaultPersonaId}
documentSidebarInitialWidth={finalDocumentSidebarInitialWidth}
defaultSidebarTab={defaultSidebarTab}

View File

@ -0,0 +1,10 @@
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
import { fetchSS } from "../utilsSS";
export async function fetchLLMProvidersSS() {
const response = await fetchSS("/llm/provider");
if (response.ok) {
return (await response.json()) as LLMProviderDescriptor[];
}
return [];
}

33
web/src/lib/llm/utils.ts Normal file
View File

@ -0,0 +1,33 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/models/llm/interfaces";
export function getFinalLLM(
llmProviders: LLMProviderDescriptor[],
persona: Persona | null
): [string, string] {
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);
let provider = defaultProvider?.name || "";
let model = defaultProvider?.default_model_name || "";
if (persona) {
provider = persona.llm_model_provider_override || provider;
model = persona.llm_model_version_override || model;
}
return [provider, model];
}
const MODELS_SUPPORTING_IMAGES = [
["openai", "gpt-4-vision-preview"],
["openai", "gpt-4-turbo"],
["openai", "gpt-4-1106-vision-preview"],
];
export function checkLLMSupportsImageInput(provider: string, model: string) {
return MODELS_SUPPORTING_IMAGES.some(
([p, m]) => p === provider && m === model
);
}