diff --git a/backend/alembic/versions/3879338f8ba1_add_tool_table.py b/backend/alembic/versions/3879338f8ba1_add_tool_table.py index 242eb6645..f4d5cb78e 100644 --- a/backend/alembic/versions/3879338f8ba1_add_tool_table.py +++ b/backend/alembic/versions/3879338f8ba1_add_tool_table.py @@ -11,8 +11,8 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "3879338f8ba1" down_revision = "f1c6478c3fd8" -branch_labels = None -depends_on = None +branch_labels: None = None +depends_on: None = None def upgrade() -> None: diff --git a/backend/alembic/versions/70f00c45c0f2_more_descriptive_filestore.py b/backend/alembic/versions/70f00c45c0f2_more_descriptive_filestore.py new file mode 100644 index 000000000..3748553c3 --- /dev/null +++ b/backend/alembic/versions/70f00c45c0f2_more_descriptive_filestore.py @@ -0,0 +1,68 @@ +"""More Descriptive Filestore + +Revision ID: 70f00c45c0f2 +Revises: 3879338f8ba1 +Create Date: 2024-05-17 17:51:41.926893 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "70f00c45c0f2" +down_revision = "3879338f8ba1" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True)) + op.add_column( + "file_store", + sa.Column( + "file_origin", + sa.String(), + nullable=False, + server_default="connector", # Default to connector + ), + ) + op.add_column( + "file_store", + sa.Column( + "file_type", sa.String(), nullable=False, server_default="text/plain" + ), + ) + op.add_column( + "file_store", + sa.Column( + "file_metadata", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + + op.execute( + """ + UPDATE file_store + SET file_origin = CASE + WHEN file_name LIKE 'chat__%' THEN 'chat_upload' + ELSE 'connector' + END, + file_name = CASE + WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7) + ELSE file_name + END, + file_type = CASE + WHEN file_name LIKE 'chat__%' THEN 'image/png' + ELSE 'text/plain' + END + """ + ) + + +def downgrade() -> None: + op.drop_column("file_store", "file_metadata") + op.drop_column("file_store", "file_type") + op.drop_column("file_store", "file_origin") + op.drop_column("file_store", "display_name") diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 7ee92d79c..06ac84fa0 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -16,6 +16,7 @@ from danswer.chat.models import StreamingError from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.constants import MessageType +from danswer.db.chat import attach_files_to_chat_message from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message from danswer.db.chat import get_chat_message @@ -240,6 +241,7 @@ def stream_chat_message_objects( else: parent_message = root_message + user_message = None if not use_existing_user_message: # Create new message at the right place in the tree and update the parent's child pointer # Don't commit yet until we verify the chat message chain @@ -250,10 +252,7 @@ def stream_chat_message_objects( message=message_text, token_count=len(llm_tokenizer_encode_func(message_text)), message_type=MessageType.USER, - files=[ - {"id": str(file_id), "type": ChatFileType.IMAGE} - for file_id in new_msg_req.file_ids - ], + files=None, # Need to attach later for optimization to only load files once in parallel db_session=db_session, commit=False, ) @@ -283,11 +282,24 @@ def stream_chat_message_objects( ) # load all files needed for this chat chain in memory - files = load_all_chat_files(history_msgs, new_msg_req.file_ids, db_session) + files = load_all_chat_files( + history_msgs, new_msg_req.file_descriptors, db_session + ) latest_query_files = [ - file for file in files if file.file_id in new_msg_req.file_ids + file + for file in files + if file.file_id in [f["id"] for f in new_msg_req.file_descriptors] ] + if user_message: + attach_files_to_chat_message( + chat_message=user_message, + files=[ + new_file.to_file_descriptor() for new_file in latest_query_files + ], + db_session=db_session, + ) + selected_db_search_docs = None selected_llm_docs: list[LlmDoc] | None = None if reference_doc_ids: diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index b6a7bae59..bfd5644b4 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -130,3 +130,9 @@ class TokenRateLimitScope(str, Enum): USER = "user" USER_GROUP = "user_group" GLOBAL = "global" + + +class FileOrigin(str, Enum): + CHAT_UPLOAD = "chat_upload" + CHAT_IMAGE_GEN = "chat_image_gen" + CONNECTOR = "connector" diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 2d3d8a33d..b4a0b9bf2 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -319,6 +319,17 @@ def set_as_latest_chat_message( db_session.commit() +def attach_files_to_chat_message( + chat_message: ChatMessage, + files: list[FileDescriptor], + db_session: Session, + commit: bool = True, +) -> None: + chat_message.files = files + if commit: + db_session.commit() + + def get_prompt_by_id( prompt_id: int, user: User | None, diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 37a5ded42..f2bc7cdb9 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -35,6 +35,7 @@ from sqlalchemy.types import TypeDecorator from danswer.auth.schemas import UserRole from danswer.configs.constants import DEFAULT_BOOST from danswer.configs.constants import DocumentSource +from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType from danswer.configs.constants import SearchFeedbackType from danswer.configs.constants import TokenRateLimitScope @@ -1071,8 +1072,13 @@ class KVStore(Base): class PGFileStore(Base): __tablename__ = "file_store" - file_name = mapped_column(String, primary_key=True) - lobj_oid = mapped_column(Integer, nullable=False) + + file_name: Mapped[str] = mapped_column(String, primary_key=True) + display_name: Mapped[str] = mapped_column(String, nullable=True) + file_origin: Mapped[FileOrigin] = mapped_column(Enum(FileOrigin, native_enum=False)) + file_type: Mapped[str] = mapped_column(String, default="text/plain") + file_metadata: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) + lobj_oid: Mapped[int] = mapped_column(Integer, nullable=False) """ diff --git a/backend/danswer/db/pg_file_store.py b/backend/danswer/db/pg_file_store.py index 91a57adab..7146fc75b 100644 --- a/backend/danswer/db/pg_file_store.py +++ b/backend/danswer/db/pg_file_store.py @@ -4,6 +4,7 @@ from typing import IO from psycopg2.extensions import connection from sqlalchemy.orm import Session +from danswer.configs.constants import FileOrigin from danswer.db.models import PGFileStore from danswer.utils.logger import setup_logger @@ -48,7 +49,14 @@ def delete_lobj_by_id( def upsert_pgfilestore( - file_name: str, lobj_oid: int, db_session: Session, commit: bool = False + file_name: str, + display_name: str | None, + file_origin: FileOrigin, + file_type: str, + lobj_oid: int, + db_session: Session, + commit: bool = False, + file_metadata: dict | None = None, ) -> PGFileStore: pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first() @@ -65,7 +73,14 @@ def upsert_pgfilestore( pgfilestore.lobj_oid = lobj_oid else: - pgfilestore = PGFileStore(file_name=file_name, lobj_oid=lobj_oid) + pgfilestore = PGFileStore( + file_name=file_name, + display_name=display_name, + file_origin=file_origin, + file_type=file_type, + file_metadata=file_metadata, + lobj_oid=lobj_oid, + ) db_session.add(pgfilestore) if commit: diff --git a/backend/danswer/file_processing/extract_file_text.py b/backend/danswer/file_processing/extract_file_text.py index 05989a539..710771eef 100644 --- a/backend/danswer/file_processing/extract_file_text.py +++ b/backend/danswer/file_processing/extract_file_text.py @@ -254,9 +254,12 @@ def file_io_to_text(file: IO[Any]) -> str: def extract_file_text( - file_name: str, + file_name: str | None, file: IO[Any], ) -> str: + if not file_name: + return file_io_to_text(file) + extension = get_file_ext(file_name) if not check_file_ext_is_valid(extension): raise RuntimeError("Unprocessable file type") diff --git a/backend/danswer/file_store/file_store.py b/backend/danswer/file_store/file_store.py index f0a44bf5d..9e131d38c 100644 --- a/backend/danswer/file_store/file_store.py +++ b/backend/danswer/file_store/file_store.py @@ -4,6 +4,7 @@ from typing import IO from sqlalchemy.orm import Session +from danswer.configs.constants import FileOrigin from danswer.db.pg_file_store import create_populate_lobj from danswer.db.pg_file_store import delete_lobj_by_id from danswer.db.pg_file_store import delete_pgfilestore_by_file_name @@ -18,7 +19,14 @@ class FileStore(ABC): """ @abstractmethod - def save_file(self, file_name: str, content: IO) -> None: + def save_file( + self, + file_name: str, + content: IO, + display_name: str | None, + file_origin: FileOrigin, + file_type: str, + ) -> None: """ Save a file to the blob store @@ -26,6 +34,9 @@ class FileStore(ABC): - connector_name: Name of the CC-Pair (as specified by the user in the UI) - file_name: Name of the file to save - content: Contents of the file + - display_name: Display name of the file + - file_origin: Origin of the file + - file_type: Type of the file """ raise NotImplementedError @@ -55,13 +66,25 @@ class PostgresBackedFileStore(FileStore): def __init__(self, db_session: Session): self.db_session = db_session - def save_file(self, file_name: str, content: IO) -> None: + def save_file( + self, + file_name: str, + content: IO, + display_name: str | None, + file_origin: FileOrigin, + file_type: str, + ) -> None: try: - # The large objects in postgres are saved as special objects can can be listed with + # The large objects in postgres are saved as special objects can be listed with # SELECT * FROM pg_largeobject_metadata; obj_id = create_populate_lobj(content=content, db_session=self.db_session) upsert_pgfilestore( - file_name=file_name, lobj_oid=obj_id, db_session=self.db_session + file_name=file_name, + display_name=display_name or file_name, + file_origin=file_origin, + file_type=file_type, + lobj_oid=obj_id, + db_session=self.db_session, ) self.db_session.commit() except Exception: diff --git a/backend/danswer/file_store/models.py b/backend/danswer/file_store/models.py index 510f12cf9..f26fa4ca5 100644 --- a/backend/danswer/file_store/models.py +++ b/backend/danswer/file_store/models.py @@ -1,13 +1,18 @@ import base64 from enum import Enum +from typing import NotRequired from typing import TypedDict -from uuid import UUID from pydantic import BaseModel class ChatFileType(str, Enum): + # Image types only contain the binary data IMAGE = "image" + # Doc types are saved as both the binary, and the parsed text + DOC = "document" + # Plain text only contain the text + PLAIN_TEXT = "plain_text" class FileDescriptor(TypedDict): @@ -16,18 +21,26 @@ class FileDescriptor(TypedDict): id: str type: ChatFileType + name: NotRequired[str | None] class InMemoryChatFile(BaseModel): - file_id: UUID + file_id: str content: bytes - file_type: ChatFileType = ChatFileType.IMAGE + file_type: ChatFileType + filename: str | None = None def to_base64(self) -> str: - return base64.b64encode(self.content).decode() + if self.file_type == ChatFileType.IMAGE: + return base64.b64encode(self.content).decode() + else: + raise RuntimeError( + "Should not be trying to convert a non-image file to base64" + ) def to_file_descriptor(self) -> FileDescriptor: return { "id": str(self.file_id), "type": self.file_type, + "name": self.filename, } diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index e487e8f7c..82c027304 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -1,50 +1,56 @@ from io import BytesIO from typing import cast -from uuid import UUID from uuid import uuid4 import requests from sqlalchemy.orm import Session +from danswer.configs.constants import FileOrigin from danswer.db.engine import get_session_context_manager from danswer.db.models import ChatMessage from danswer.file_store.file_store import get_default_file_store +from danswer.file_store.models import FileDescriptor from danswer.file_store.models import InMemoryChatFile from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -def build_chat_file_name(file_id: UUID | str) -> str: - return f"chat__{file_id}" - - -def load_chat_file(file_id: UUID, db_session: Session) -> InMemoryChatFile: +def load_chat_file( + file_descriptor: FileDescriptor, db_session: Session +) -> InMemoryChatFile: file_io = get_default_file_store(db_session).read_file( - build_chat_file_name(file_id), mode="b" + file_descriptor["id"], mode="b" + ) + return InMemoryChatFile( + file_id=file_descriptor["id"], + content=file_io.read(), + file_type=file_descriptor["type"], + filename=file_descriptor["name"], ) - return InMemoryChatFile(file_id=file_id, content=file_io.read()) def load_all_chat_files( - chat_messages: list[ChatMessage], new_file_ids: list[UUID], db_session: Session + chat_messages: list[ChatMessage], + file_descriptors: list[FileDescriptor], + db_session: Session, ) -> list[InMemoryChatFile]: - file_ids_for_history = [] + file_descriptors_for_history: list[FileDescriptor] = [] for chat_message in chat_messages: if chat_message.files: - file_ids_for_history.extend([file["id"] for file in chat_message.files]) + file_descriptors_for_history.extend(chat_message.files) files = cast( list[InMemoryChatFile], run_functions_tuples_in_parallel( [ - (load_chat_file, (file_id, db_session)) - for file_id in new_file_ids + file_ids_for_history + (load_chat_file, (file, db_session)) + for file in file_descriptors + file_descriptors_for_history ] ), ) return files -def save_file_from_url(url: str) -> UUID: +def save_file_from_url(url: str) -> str: """NOTE: using multiple sessions here, since this is often called using multithreading. In practice, sharing a session has resulted in weird errors.""" @@ -52,15 +58,20 @@ def save_file_from_url(url: str) -> UUID: response = requests.get(url) response.raise_for_status() - file_id = uuid4() - file_name = build_chat_file_name(file_id) + unique_id = str(uuid4()) file_io = BytesIO(response.content) file_store = get_default_file_store(db_session) - file_store.save_file(file_name=file_name, content=file_io) - return file_id + file_store.save_file( + file_name=unique_id, + content=file_io, + display_name="GeneratedImage", + file_origin=FileOrigin.CHAT_IMAGE_GEN, + file_type="image/png;base64", + ) + return unique_id -def save_files_from_urls(urls: list[str]) -> list[UUID]: +def save_files_from_urls(urls: list[str]) -> list[str]: funcs = [(save_file_from_url, (url,)) for url in urls] return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index b3a88421e..41f8e1090 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -86,6 +86,7 @@ class Answer: message_history: list[PreviousMessage] | None = None, single_message_history: str | None = None, # newly passed in files to include as part of this question + # TODO THIS NEEDS TO BE HANDLED latest_query_files: list[InMemoryChatFile] | None = None, files: list[InMemoryChatFile] | None = None, tools: list[Tool] | None = None, diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 22c096146..a526adddc 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -24,8 +24,10 @@ from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MAX_TOKENS from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.db.models import ChatMessage +from danswer.file_store.models import ChatFileType from danswer.file_store.models import InMemoryChatFile from danswer.llm.interfaces import LLM +from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.search.models import InferenceChunk from danswer.utils.logger import setup_logger from shared_configs.configs import LOG_LEVEL @@ -113,23 +115,50 @@ def translate_history_to_basemessages( return history_basemessages, history_token_counts +def _build_content( + message: str, + files: list[InMemoryChatFile] | None = None, +) -> str: + """Applies all non-image files.""" + text_files = ( + [file for file in files if file.file_type == ChatFileType.PLAIN_TEXT] + if files + else None + ) + if not text_files: + return message + + final_message_with_files = "FILES:\n\n" + for file in text_files: + file_content = file.content.decode("utf-8") + file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else "" + final_message_with_files += ( + f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n" + ) + final_message_with_files += message + + return final_message_with_files + + def build_content_with_imgs( message: str, files: list[InMemoryChatFile] | None = None, img_urls: list[str] | None = None, ) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type - if not files and not img_urls: - return message - files = files or [] + img_files = [file for file in files if file.file_type == ChatFileType.IMAGE] img_urls = img_urls or [] + message_main_content = _build_content(message, files) + + if not img_files and not img_urls: + return message_main_content return cast( list[str | dict[str, Any]], [ { "type": "text", - "text": message, + "text": message_main_content, }, ] + [ diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index a5e4e524d..f612d7a87 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -16,6 +16,7 @@ from danswer.auth.users import current_user from danswer.background.celery.celery_utils import get_deletion_status from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES from danswer.configs.constants import DocumentSource +from danswer.configs.constants import FileOrigin from danswer.connectors.gmail.connector_auth import delete_gmail_service_account_key from danswer.connectors.gmail.connector_auth import delete_google_app_gmail_cred from danswer.connectors.gmail.connector_auth import get_gmail_auth_url @@ -351,7 +352,13 @@ def upload_files( for file in files: file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename)) deduped_file_paths.append(file_path) - file_store.save_file(file_name=file_path, content=file.file) + file_store.save_file( + file_name=file_path, + content=file.file, + display_name=file.filename, + file_origin=FileOrigin.CONNECTOR, + file_type=file.content_type or "text/plain", + ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return FileUploadResponse(file_paths=deduped_file_paths) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 60de8ea96..97dcc62d0 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -1,3 +1,4 @@ +import io import uuid from fastapi import APIRouter @@ -13,6 +14,7 @@ from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import FileOrigin from danswer.configs.constants import MessageType from danswer.db.chat import create_chat_session from danswer.db.chat import create_new_chat_message @@ -32,8 +34,10 @@ from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.models import User from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index +from danswer.file_processing.extract_file_text import extract_file_text from danswer.file_store.file_store import get_default_file_store -from danswer.file_store.utils import build_chat_file_name +from danswer.file_store.models import ChatFileType +from danswer.file_store.models import FileDescriptor from danswer.llm.answering.prompts.citations_prompt import ( compute_max_document_tokens_for_persona, ) @@ -422,15 +426,51 @@ def upload_files_for_chat( files: list[UploadFile], db_session: Session = Depends(get_session), _: User | None = Depends(current_user), -) -> dict[str, list[uuid.UUID]]: - for file in files: - if file.content_type not in ("image/jpeg", "image/png", "image/webp"): - raise HTTPException( - status_code=400, - detail="Only .jpg, .jpeg, .png, and .webp files are currently supported", - ) +) -> dict[str, list[FileDescriptor]]: + image_content_types = {"image/jpeg", "image/png", "image/webp"} + text_content_types = { + "text/plain", + "text/csv", + "text/markdown", + "text/x-markdown", + "text/x-config", + "text/tab-separated-values", + "application/json", + "application/xml", + "application/x-yaml", + } + document_content_types = { + "application/pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "message/rfc822", + "application/epub+zip", + } - if file.size and file.size > 20 * 1024 * 1024: + allowed_content_types = image_content_types.union(text_content_types).union( + document_content_types + ) + + for file in files: + if file.content_type not in allowed_content_types: + if file.content_type in image_content_types: + error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp." + elif file.content_type in text_content_types: + error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, " + ".log, .tsv." + else: + error_detail = ( + "Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, " + ".json, .xml, .yml, .yaml, .eml, .epub." + ) + raise HTTPException(status_code=400, detail=error_detail) + + if ( + file.content_type in image_content_types + and file.size + and file.size > 20 * 1024 * 1024 + ): raise HTTPException( status_code=400, detail="File size must be less than 20MB", @@ -438,14 +478,50 @@ def upload_files_for_chat( file_store = get_default_file_store(db_session) - file_ids = [] + file_info: list[tuple[str, str | None, ChatFileType]] = [] for file in files: - file_id = uuid.uuid4() - file_name = build_chat_file_name(file_id) - file_store.save_file(file_name=file_name, content=file.file) - file_ids.append(file_id) + if file.content_type in image_content_types: + file_type = ChatFileType.IMAGE + elif file.content_type in document_content_types: + file_type = ChatFileType.DOC + else: + file_type = ChatFileType.PLAIN_TEXT - return {"file_ids": file_ids} + # store the raw file + file_id = str(uuid.uuid4()) + file_store.save_file( + file_name=file_id, + content=file.file, + display_name=file.filename, + file_origin=FileOrigin.CHAT_UPLOAD, + file_type=file.content_type or file_type.value, + ) + + # if the file is a doc, extract text and store that so we don't need + # to re-extract it every time we send a message + if file_type == ChatFileType.DOC: + extracted_text = extract_file_text(file_name=file.filename, file=file.file) + text_file_id = str(uuid.uuid4()) + file_store.save_file( + file_name=text_file_id, + content=io.BytesIO(extracted_text.encode()), + display_name=file.filename, + file_origin=FileOrigin.CHAT_UPLOAD, + file_type="text/plain", + ) + # for DOC type, just return this for the FileDescriptor + # as we would always use this as the ID to attach to the + # message + file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT)) + else: + file_info.append((file_id, file.filename, file_type)) + + return { + "files": [ + {"id": file_id, "type": file_type, "name": file_name} + for file_id, file_name, file_type in file_info + ] + } @router.get("/file/{file_id}") @@ -455,7 +531,7 @@ def fetch_chat_file( _: User | None = Depends(current_user), ) -> Response: file_store = get_default_file_store(db_session) - file_io = file_store.read_file(build_chat_file_name(file_id), mode="b") + file_io = file_store.read_file(file_id, mode="b") # NOTE: specifying "image/jpeg" here, but it still works for pngs # TODO: do this properly return Response(content=file_io.read(), media_type="image/jpeg") diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 1f70af298..44e8ab846 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -1,6 +1,5 @@ from datetime import datetime from typing import Any -from uuid import UUID from pydantic import BaseModel from pydantic import root_validator @@ -85,7 +84,7 @@ class CreateChatMessageRequest(ChunkContext): # New message contents message: str # file's that we should attach to this message - file_ids: list[UUID] + file_descriptors: list[FileDescriptor] # If no prompt provided, uses the largest prompt of the chat session # but really this should be explicitly specified, only in the simplified APIs is this inferred # Use prompt_id 0 to use the system default prompt which is Answer-Question diff --git a/web/package-lock.json b/web/package-lock.json index 8031b23f5..af5f5078b 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -14,12 +14,14 @@ "@phosphor-icons/react": "^2.0.8", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-tooltip": "^1.0.7", "@tremor/react": "^3.9.2", "@types/js-cookie": "^3.0.3", "@types/lodash": "^4.17.0", "@types/node": "18.15.11", "@types/react": "18.0.32", "@types/react-dom": "18.0.11", + "@types/uuid": "^9.0.8", "autoprefixer": "^10.4.14", "formik": "^2.2.9", "js-cookie": "^3.0.5", @@ -40,6 +42,7 @@ "swr": "^2.1.5", "tailwindcss": "^3.3.1", "typescript": "5.0.3", + "uuid": "^9.0.1", "yup": "^1.1.1" }, "devDependencies": { @@ -1383,6 +1386,40 @@ } } }, + "node_modules/@radix-ui/react-tooltip": { + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.0.7.tgz", + "integrity": "sha512-lPh5iKNFVQ/jav/j6ZrWq3blfDJ0OH9R6FlNUHPMqdLuQ9vwDgFsRxvl8b7Asuy5c8xmoojHUxKHQSOAvMHxyw==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/primitive": "1.0.1", + "@radix-ui/react-compose-refs": "1.0.1", + "@radix-ui/react-context": "1.0.1", + "@radix-ui/react-dismissable-layer": "1.0.5", + "@radix-ui/react-id": "1.0.1", + "@radix-ui/react-popper": "1.1.3", + "@radix-ui/react-portal": "1.0.4", + "@radix-ui/react-presence": "1.0.1", + "@radix-ui/react-primitive": "1.0.3", + "@radix-ui/react-slot": "1.0.2", + "@radix-ui/react-use-controllable-state": "1.0.1", + "@radix-ui/react-visually-hidden": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/react-use-callback-ref": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz", @@ -1489,6 +1526,29 @@ } } }, + "node_modules/@radix-ui/react-visually-hidden": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.0.3.tgz", + "integrity": "sha512-D4w41yN5YRKtu464TLnByKzMDG/JlMPHtfZgQAu9v6mNakUqGUI9vUrfQKz8NK41VMm/xbZbh76NUTVtIYqOMA==", + "dependencies": { + "@babel/runtime": "^7.13.10", + "@radix-ui/react-primitive": "1.0.3" + }, + "peerDependencies": { + "@types/react": "*", + "@types/react-dom": "*", + "react": "^16.8 || ^17.0 || ^18.0", + "react-dom": "^16.8 || ^17.0 || ^18.0" + }, + "peerDependenciesMeta": { + "@types/react": { + "optional": true + }, + "@types/react-dom": { + "optional": true + } + } + }, "node_modules/@radix-ui/rect": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/@radix-ui/rect/-/rect-1.0.1.tgz", @@ -1735,6 +1795,11 @@ "resolved": "https://registry.npmjs.org/@types/unist/-/unist-3.0.2.tgz", "integrity": "sha512-dqId9J8K/vGi5Zr7oo212BGii5m3q5Hxlkwy3WpYuKPklmBEvsbMYYyLxAQpSffdLl/gdW0XUpKWFvYmyoWCoQ==" }, + "node_modules/@types/uuid": { + "version": "9.0.8", + "resolved": "https://registry.npmjs.org/@types/uuid/-/uuid-9.0.8.tgz", + "integrity": "sha512-jg+97EGIcY9AGHJJRaaPVgetKDsrTgbRjQ5Msgjh/DQKEFl0DtyRr/VCOyD1T2R1MNeWPK/u7JoGhlDZnKBAfA==" + }, "node_modules/@typescript-eslint/parser": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-7.2.0.tgz", @@ -2344,6 +2409,7 @@ "version": "1.1.11", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.11.tgz", "integrity": "sha512-iCuPHDFgrHX7H2vEI/5xpz07zSHB00TpugqhmYtVmMO6518mCuRMoOYFldEBl0g187ufozdaHgWKcYFb61qGiA==", + "dev": true, "dependencies": { "balanced-match": "^1.0.0", "concat-map": "0.0.1" @@ -2501,6 +2567,7 @@ "version": "4.1.2", "resolved": "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz", "integrity": "sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==", + "dev": true, "dependencies": { "ansi-styles": "^4.1.0", "supports-color": "^7.1.0" @@ -2657,7 +2724,8 @@ "node_modules/concat-map": { "version": "0.0.1", "resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz", - "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==" + "integrity": "sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==", + "dev": true }, "node_modules/convert-source-map": { "version": "2.0.0", @@ -4185,6 +4253,7 @@ "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", + "dev": true, "engines": { "node": ">=8" } @@ -5930,6 +5999,7 @@ "version": "3.1.2", "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "dev": true, "dependencies": { "brace-expansion": "^1.1.7" }, @@ -10216,6 +10286,7 @@ "version": "7.2.0", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", "integrity": "sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==", + "dev": true, "dependencies": { "has-flag": "^4.0.0" }, @@ -10343,7 +10414,8 @@ "node_modules/text-table": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/text-table/-/text-table-0.2.0.tgz", - "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==" + "integrity": "sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==", + "dev": true }, "node_modules/thenify": { "version": "3.3.1", @@ -10788,6 +10860,18 @@ "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==" }, + "node_modules/uuid": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-9.0.1.tgz", + "integrity": "sha512-b+1eJOlsR9K8HJpow9Ok3fiWOWSIcIzXodvv0rQjVoOVNpWMpxf1wZNpt4y9h10odCNrqnYp1OBzRktckBe3sA==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "bin": { + "uuid": "dist/bin/uuid" + } + }, "node_modules/vfile": { "version": "6.0.1", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.1.tgz", @@ -11089,126 +11173,6 @@ "type": "github", "url": "https://github.com/sponsors/wooorm" } - }, - "node_modules/@next/swc-darwin-x64": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-14.2.3.tgz", - "integrity": "sha512-6adp7waE6P1TYFSXpY366xwsOnEXM+y1kgRpjSRVI2CBDOcbRjsJ67Z6EgKIqWIue52d2q/Mx8g9MszARj8IEA==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-arm64-gnu": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.2.3.tgz", - "integrity": "sha512-cuzCE/1G0ZSnTAHJPUT1rPgQx1w5tzSX7POXSLaS7w2nIUJUD+e25QoXD/hMfxbsT9rslEXugWypJMILBj/QsA==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-arm64-musl": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.2.3.tgz", - "integrity": "sha512-0D4/oMM2Y9Ta3nGuCcQN8jjJjmDPYpHX9OJzqk42NZGJocU2MqhBq5tWkJrUQOQY9N+In9xOdymzapM09GeiZw==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-x64-gnu": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.2.3.tgz", - "integrity": "sha512-ENPiNnBNDInBLyUU5ii8PMQh+4XLr4pG51tOp6aJ9xqFQ2iRI6IH0Ds2yJkAzNV1CfyagcyzPfROMViS2wOZ9w==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-linux-x64-musl": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.2.3.tgz", - "integrity": "sha512-BTAbq0LnCbF5MtoM7I/9UeUu/8ZBY0i8SFjUMCbPDOLv+un67e2JgyN4pmgfXBwy/I+RHu8q+k+MCkDN6P9ViQ==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-win32-arm64-msvc": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.2.3.tgz", - "integrity": "sha512-AEHIw/dhAMLNFJFJIJIyOFDzrzI5bAjI9J26gbO5xhAKHYTZ9Or04BesFPXiAYXDNdrwTP2dQceYA4dL1geu8A==", - "cpu": [ - "arm64" - ], - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-win32-ia32-msvc": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.3.tgz", - "integrity": "sha512-vga40n1q6aYb0CLrM+eEmisfKCR45ixQYXuBXxOOmmoV8sYST9k7E3US32FsY+CkkF7NtzdcebiFT4CHuMSyZw==", - "cpu": [ - "ia32" - ], - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">= 10" - } - }, - "node_modules/@next/swc-win32-x64-msvc": { - "version": "14.2.3", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.2.3.tgz", - "integrity": "sha512-Q1/zm43RWynxrO7lW4ehciQVj+5ePBhOK+/K2P7pLFX3JaJ/IZVC69SHidrmZSOkqz7ECIOhhy7XhAFG4JYyHA==", - "cpu": [ - "x64" - ], - "optional": true, - "os": [ - "win32" - ], - "engines": { - "node": ">= 10" - } } } } diff --git a/web/package.json b/web/package.json index 33a634979..ba7b23ea2 100644 --- a/web/package.json +++ b/web/package.json @@ -15,12 +15,14 @@ "@phosphor-icons/react": "^2.0.8", "@radix-ui/react-dialog": "^1.0.5", "@radix-ui/react-popover": "^1.0.7", + "@radix-ui/react-tooltip": "^1.0.7", "@tremor/react": "^3.9.2", "@types/js-cookie": "^3.0.3", "@types/lodash": "^4.17.0", "@types/node": "18.15.11", "@types/react": "18.0.32", "@types/react-dom": "18.0.11", + "@types/uuid": "^9.0.8", "autoprefixer": "^10.4.14", "formik": "^2.2.9", "js-cookie": "^3.0.5", @@ -41,6 +43,7 @@ "swr": "^2.1.5", "tailwindcss": "^3.3.1", "typescript": "5.0.3", + "uuid": "^9.0.1", "yup": "^1.1.1" }, "devDependencies": { diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index a8b4e570d..477d5ca86 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -4,6 +4,7 @@ import { useRouter, useSearchParams } from "next/navigation"; import { BackendChatSession, BackendMessage, + ChatFileType, ChatSession, ChatSessionSharedStatus, DocumentsResponse, @@ -66,12 +67,13 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; import { LLMProviderDescriptor } from "../admin/models/llm/interfaces"; import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils"; -import { InputBarPreviewImage } from "./images/InputBarPreviewImage"; +import { InputBarPreviewImage } from "./files/images/InputBarPreviewImage"; import { Folder } from "./folders/interfaces"; import { ChatInputBar } from "./input/ChatInputBar"; import { ConfigurationModal } from "./modal/configuration/ConfigurationModal"; import { useChatContext } from "@/components/context/ChatContext"; import { UserDropdown } from "@/components/UserDropdown"; +import { v4 as uuidv4 } from "uuid"; const MAX_INPUT_HEIGHT = 200; const TEMP_USER_MESSAGE_ID = -1; @@ -131,7 +133,7 @@ export function ChatPage({ useEffect(() => { urlChatSessionId.current = existingChatSessionId; - textareaRef.current?.focus(); + textAreaRef.current?.focus(); // only clear things if we're going from one chat session to another if (chatSessionId !== null && existingChatSessionId !== chatSessionId) { @@ -150,7 +152,7 @@ export function ChatPage({ }); llmOverrideManager.setTemperature(null); // remove uploaded files - setCurrentMessageFileIds([]); + setCurrentMessageFiles([]); if (isStreaming) { setIsCancelled(true); @@ -317,9 +319,9 @@ export function ChatPage({ const [isStreaming, setIsStreaming] = useState(false); // uploaded files - const [currentMessageFileIds, setCurrentMessageFileIds] = useState( - [] - ); + const [currentMessageFiles, setCurrentMessageFiles] = useState< + FileDescriptor[] + >([]); // for document display // NOTE: -1 is a special designation that means the latest AI message @@ -423,9 +425,9 @@ export function ChatPage({ }, [isFetchingChatMessages]); // handle re-sizing of the text area - const textareaRef = useRef(null); + const textAreaRef = useRef(null); useEffect(() => { - const textarea = textareaRef.current; + const textarea = textAreaRef.current; if (textarea) { textarea.style.height = "0px"; textarea.style.height = `${Math.min( @@ -528,10 +530,6 @@ export function ChatPage({ (currMessageHistory.length > 0 ? currMessageHistory[currMessageHistory.length - 1] : null); - const currFiles = currentMessageFileIds.map((id) => ({ - id, - type: "image", - })) as FileDescriptor[]; // if we're resending, set the parent's child to null // we will use tempMessages until the regenerated message is complete @@ -540,7 +538,7 @@ export function ChatPage({ messageId: TEMP_USER_MESSAGE_ID, message: currMessage, type: "user", - files: currFiles, + files: currentMessageFiles, parentMessageId: parentMessage?.messageId || null, }, ]; @@ -562,7 +560,7 @@ export function ChatPage({ parentMessage = frozenCompleteMessageMap.get(SYSTEM_MESSAGE_ID) || null; } setMessage(""); - setCurrentMessageFileIds([]); + setCurrentMessageFiles([]); setIsStreaming(true); let answer = ""; @@ -580,7 +578,7 @@ export function ChatPage({ getLastSuccessfulMessageId(currMessageHistory); for await (const packetBunch of sendMessage({ message: currMessage, - fileIds: currentMessageFileIds, + fileDescriptors: currentMessageFiles, parentMessageId: lastSuccessfulMessageId, chatSessionId: currChatSessionId, promptId: livePersona?.prompts[0]?.id || 0, @@ -628,7 +626,7 @@ export function ChatPage({ (fileId) => { return { id: fileId, - type: "image", + type: ChatFileType.IMAGE, }; } ); @@ -662,7 +660,7 @@ export function ChatPage({ messageId: newUserMessageId, message: currMessage, type: "user", - files: currFiles, + files: currentMessageFiles, parentMessageId: parentMessage?.messageId || null, childrenMessageIds: [newAssistantMessageId], latestChildMessageId: newAssistantMessageId, @@ -692,7 +690,7 @@ export function ChatPage({ messageId: TEMP_USER_MESSAGE_ID, message: currMessage, type: "user", - files: currFiles, + files: currentMessageFiles, parentMessageId: null, }, { @@ -769,10 +767,10 @@ export function ChatPage({ const onPersonaChange = (persona: Persona | null) => { if (persona && persona.id !== livePersona.id) { // remove uploaded files - setCurrentMessageFileIds([]); + setCurrentMessageFiles([]); setSelectedPersona(persona); - textareaRef.current?.focus(); + textAreaRef.current?.focus(); router.push(buildChatUrl(searchParams, null, persona.id)); } }; @@ -781,7 +779,10 @@ export function ChatPage({ const llmAcceptsImages = checkLLMSupportsImageInput( ...getFinalLLM(llmProviders, livePersona) ); - if (!llmAcceptsImages) { + const imageFiles = acceptedFiles.filter((file) => + file.type.startsWith("image/") + ); + if (imageFiles.length > 0 && !llmAcceptsImages) { setPopup({ type: "error", message: @@ -790,15 +791,35 @@ export function ChatPage({ return; } - uploadFilesForChat(acceptedFiles).then(([fileIds, error]) => { + const tempFileDescriptors = acceptedFiles.map((file) => ({ + id: uuidv4(), + type: file.type.startsWith("image/") + ? ChatFileType.IMAGE + : ChatFileType.DOCUMENT, + isUploading: true, + })); + + // only show loading spinner for reasonably large files + const totalSize = acceptedFiles.reduce((sum, file) => sum + file.size, 0); + if (totalSize > 50 * 1024) { + setCurrentMessageFiles((prev) => [...prev, ...tempFileDescriptors]); + } + + const removeTempFiles = (prev: FileDescriptor[]) => { + return prev.filter( + (file) => !tempFileDescriptors.some((newFile) => newFile.id === file.id) + ); + }; + + uploadFilesForChat(acceptedFiles).then(([files, error]) => { if (error) { + setCurrentMessageFiles((prev) => removeTempFiles(prev)); setPopup({ type: "error", message: error, }); } else { - const newFileIds = [...currentMessageFileIds, ...fileIds]; - setCurrentMessageFileIds(newFileIds); + setCurrentMessageFiles((prev) => [...removeTempFiles(prev), ...files]); } }); }; @@ -884,38 +905,40 @@ export function ChatPage({ > {/* */}
{livePersona && ( -
-
- -
+
+
+
+ +
-
- {chatSessionId !== null && ( -
setSharingModalVisible(true)} - className={` - my-auto - p-2 - rounded - cursor-pointer - hover:bg-hover-light - `} - > - +
+ {chatSessionId !== null && ( +
setSharingModalVisible(true)} + className={` + my-auto + p-2 + rounded + cursor-pointer + hover:bg-hover-light + `} + > + +
+ )} + +
+
- )} - -
-
@@ -930,7 +953,7 @@ export function ChatPage({ selectedPersona={selectedPersona} handlePersonaSelect={(persona) => { setSelectedPersona(persona); - textareaRef.current?.focus(); + textAreaRef.current?.focus(); router.push( buildChatUrl(searchParams, null, persona.id) ); @@ -1149,7 +1172,7 @@ export function ChatPage({ )} {/* Some padding at the bottom so the search bar has space at the bottom to not cover the last message*/} -
+
{livePersona && livePersona.starter_messages && @@ -1206,10 +1229,11 @@ export function ChatPage({ filterManager={filterManager} llmOverrideManager={llmOverrideManager} selectedAssistant={livePersona} - fileIds={currentMessageFileIds} - setFileIds={setCurrentMessageFileIds} + files={currentMessageFiles} + setFiles={setCurrentMessageFiles} handleFileUpload={handleImageUpload} setConfigModalActiveTab={setConfigModalActiveTab} + textAreaRef={textAreaRef} />
diff --git a/web/src/app/chat/files/InputBarPreview.tsx b/web/src/app/chat/files/InputBarPreview.tsx new file mode 100644 index 000000000..8eee7bbf9 --- /dev/null +++ b/web/src/app/chat/files/InputBarPreview.tsx @@ -0,0 +1,73 @@ +import { useState } from "react"; +import { ChatFileType, FileDescriptor } from "../interfaces"; +import { DocumentPreview } from "./documents/DocumentPreview"; +import { InputBarPreviewImage } from "./images/InputBarPreviewImage"; +import { FiX, FiLoader } from "react-icons/fi"; + +function DeleteButton({ onDelete }: { onDelete: () => void }) { + return ( + + ); +} + +export function InputBarPreview({ + file, + onDelete, + isUploading, +}: { + file: FileDescriptor; + onDelete: () => void; + isUploading: boolean; +}) { + const [isHovered, setIsHovered] = useState(false); + + const renderContent = () => { + if (file.type === ChatFileType.IMAGE) { + return ; + } + return ; + }; + + return ( +
setIsHovered(true)} + onMouseLeave={() => setIsHovered(false)} + > + {isHovered && } + {isUploading && ( +
+ +
+ )} + {renderContent()} +
+ ); +} diff --git a/web/src/app/chat/files/documents/DocumentPreview.tsx b/web/src/app/chat/files/documents/DocumentPreview.tsx new file mode 100644 index 000000000..a38e36872 --- /dev/null +++ b/web/src/app/chat/files/documents/DocumentPreview.tsx @@ -0,0 +1,67 @@ +import { FiFileText } from "react-icons/fi"; +import { useState, useRef, useEffect } from "react"; +import { Tooltip } from "@/components/tooltip/Tooltip"; + +export function DocumentPreview({ + fileName, + maxWidth, +}: { + fileName: string; + maxWidth?: string; +}) { + const [isOverflowing, setIsOverflowing] = useState(false); + const fileNameRef = useRef(null); + + useEffect(() => { + if (fileNameRef.current) { + setIsOverflowing( + fileNameRef.current.scrollWidth > fileNameRef.current.clientWidth + ); + } + }, [fileName]); + + return ( +
+
+
+ +
+
+
+ +
+ {fileName} +
+
+
Document
+
+
+ ); +} diff --git a/web/src/app/chat/images/FullImageModal.tsx b/web/src/app/chat/files/images/FullImageModal.tsx similarity index 100% rename from web/src/app/chat/images/FullImageModal.tsx rename to web/src/app/chat/files/images/FullImageModal.tsx diff --git a/web/src/app/chat/images/InMessageImage.tsx b/web/src/app/chat/files/images/InMessageImage.tsx similarity index 100% rename from web/src/app/chat/images/InMessageImage.tsx rename to web/src/app/chat/files/images/InMessageImage.tsx diff --git a/web/src/app/chat/images/InputBarPreviewImage.tsx b/web/src/app/chat/files/images/InputBarPreviewImage.tsx similarity index 50% rename from web/src/app/chat/images/InputBarPreviewImage.tsx rename to web/src/app/chat/files/images/InputBarPreviewImage.tsx index f2bd58946..372d0be60 100644 --- a/web/src/app/chat/images/InputBarPreviewImage.tsx +++ b/web/src/app/chat/files/images/InputBarPreviewImage.tsx @@ -1,18 +1,10 @@ "use client"; import { useState } from "react"; -import { FiX } from "react-icons/fi"; import { buildImgUrl } from "./utils"; import { FullImageModal } from "./FullImageModal"; -export function InputBarPreviewImage({ - fileId, - onDelete, -}: { - fileId: string; - onDelete: () => void; -}) { - const [isHovered, setIsHovered] = useState(false); +export function InputBarPreviewImage({ fileId }: { fileId: string }) { const [fullImageShowing, setFullImageShowing] = useState(false); return ( @@ -22,19 +14,7 @@ export function InputBarPreviewImage({ open={fullImageShowing} onOpenChange={(open) => setFullImageShowing(open)} /> -
setIsHovered(true)} - onMouseLeave={() => setIsHovered(false)} - > - {isHovered && ( - - )} +
setFullImageShowing(true)} className="h-16 w-16 object-cover rounded-lg bg-background cursor-pointer" diff --git a/web/src/app/chat/images/utils.ts b/web/src/app/chat/files/images/utils.ts similarity index 100% rename from web/src/app/chat/images/utils.ts rename to web/src/app/chat/files/images/utils.ts diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index d7a79c72f..c74cb171e 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -1,4 +1,4 @@ -import React, { useRef } from "react"; +import React, { useEffect, useRef } from "react"; import { FiSend, FiFilter, FiPlusCircle, FiCpu } from "react-icons/fi"; import ChatInputOption from "./ChatInputOption"; import { FaBrain } from "react-icons/fa"; @@ -7,7 +7,10 @@ import { FilterManager, LlmOverride, LlmOverrideManager } from "@/lib/hooks"; import { SelectedFilterDisplay } from "./SelectedFilterDisplay"; import { useChatContext } from "@/components/context/ChatContext"; import { getFinalLLM } from "@/lib/llm/utils"; -import { InputBarPreviewImage } from "../images/InputBarPreviewImage"; +import { FileDescriptor } from "../interfaces"; +import { InputBarPreview } from "../files/InputBarPreview"; + +const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ message, @@ -19,10 +22,11 @@ export function ChatInputBar({ filterManager, llmOverrideManager, selectedAssistant, - fileIds, - setFileIds, + files, + setFiles, handleFileUpload, setConfigModalActiveTab, + textAreaRef, }: { message: string; setMessage: (message: string) => void; @@ -33,12 +37,23 @@ export function ChatInputBar({ filterManager: FilterManager; llmOverrideManager: LlmOverrideManager; selectedAssistant: Persona; - fileIds: string[]; - setFileIds: (fileIds: string[]) => void; + files: FileDescriptor[]; + setFiles: (files: FileDescriptor[]) => void; handleFileUpload: (files: File[]) => void; setConfigModalActiveTab: (tab: string) => void; + textAreaRef: React.RefObject; }) { - const textareaRef = useRef(null); + // handle re-sizing of the text area + useEffect(() => { + const textarea = textAreaRef.current; + if (textarea) { + textarea.style.height = "0px"; + textarea.style.height = `${Math.min( + textarea.scrollHeight, + MAX_INPUT_HEIGHT + )}px`; + } + }, [message]); const { llmProviders } = useChatContext(); const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant); @@ -63,60 +78,65 @@ export function ChatInputBar({
- {fileIds.length > 0 && ( -
- {fileIds.map((fileId) => ( -
- 0 && ( +
+ {files.map((file) => ( +
+ { - setFileIds(fileIds.filter((id) => id !== fileId)); + setFiles( + files.filter( + (fileInFilter) => fileInFilter.id !== file.id + ) + ); }} + isUploading={file.isUploading || false} />
))}
)}