import asyncio
import io
import json
import os
import uuid
from collections.abc import Callable
from collections.abc import Generator
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
from onyx.chat.process_message import stream_chat_message
from onyx.chat.prompt_builder.citations_prompt import (
    compute_max_document_tokens_for_persona,
)
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from onyx.db.chat import add_chats_to_session_from_slack_thread
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import delete_all_chat_sessions_for_user
from onyx.db.chat import delete_chat_session
from onyx.db.chat import duplicate_chat_session_for_user_from_slack
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llms_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.secondary_llm_flows.chat_session_naming import (
    get_renamed_conversation_name,
)
from onyx.server.query_and_chat.models import ChatFeedbackRequest
from onyx.server.query_and_chat.models import ChatMessageIdentifier
from onyx.server.query_and_chat.models import ChatRenameRequest
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import ChatSessionDetailResponse
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report

logger = setup_logger()

router = APIRouter(prefix="/chat")


@router.get("/get-user-chat-sessions")
def get_user_chat_sessions(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    user_id = user.id if user is not None else None

    try:
        chat_sessions = get_chat_sessions_by_user(
            user_id=user_id, deleted=False, db_session=db_session
        )
    except ValueError:
        raise ValueError("Chat session does not exist or has been deleted")

    return ChatSessionsResponse(
        sessions=[
            ChatSessionDetails(
                id=chat.id,
                name=chat.description,
                persona_id=chat.persona_id,
                time_created=chat.time_created.isoformat(),
                shared_status=chat.shared_status,
                folder_id=chat.folder_id,
                current_alternate_model=chat.current_alternate_model,
                current_temperature_override=chat.temperature_override,
            )
            for chat in chat_sessions
        ]
    )


@router.put("/update-chat-session-temperature")
def update_chat_session_temperature(
    update_thread_req: UpdateChatSessionTemperatureRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    chat_session = get_chat_session_by_id(
        chat_session_id=update_thread_req.chat_session_id,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )

    # Validate temperature_override
    if update_thread_req.temperature_override is not None:
        if (
            update_thread_req.temperature_override < 0
            or update_thread_req.temperature_override > 2
        ):
            raise HTTPException(
                status_code=400, detail="Temperature must be between 0 and 2"
            )

        # Additional check for Anthropic models
        if (
            chat_session.current_alternate_model
            and "anthropic" in chat_session.current_alternate_model.lower()
        ):
            if update_thread_req.temperature_override > 1:
                raise HTTPException(
                    status_code=400,
                    detail="Temperature for Anthropic models must be between 0 and 1",
                )

    chat_session.temperature_override = update_thread_req.temperature_override

    db_session.add(chat_session)
    db_session.commit()


@router.put("/update-chat-session-model")
def update_chat_session_model(
    update_thread_req: UpdateChatSessionThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    chat_session = get_chat_session_by_id(
        chat_session_id=update_thread_req.chat_session_id,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )
    chat_session.current_alternate_model = update_thread_req.new_alternate_model

    db_session.add(chat_session)
    db_session.commit()


@router.get("/get-chat-session/{session_id}")
def get_chat_session(
    session_id: UUID,
    is_shared: bool = False,
    user: User | None = Depends(current_chat_accesssible_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
    user_id = user.id if user is not None else None
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=session_id,
            user_id=user_id,
            db_session=db_session,
            is_shared=is_shared,
        )
    except ValueError:
        raise ValueError("Chat session does not exist or has been deleted")

    # for chat-seeding: if the session is unassigned, assign it now. This is done here
    # to avoid another back and forth between FE -> BE before starting the first
    # message generation
    if chat_session.user_id is None and user_id is not None:
        chat_session.user_id = user_id
        db_session.commit()

    session_messages = get_chat_messages_by_session(
        chat_session_id=session_id,
        user_id=user_id,
        db_session=db_session,
        # we already did a permission check above with the call to
        # `get_chat_session_by_id`, so we can skip it here
        skip_permission_check=True,
        # we need the tool call objs anyways, so just fetch them in a single call
        prefetch_tool_calls=True,
    )

    return ChatSessionDetailResponse(
        chat_session_id=session_id,
        description=chat_session.description,
        persona_id=chat_session.persona_id,
        persona_name=chat_session.persona.name if chat_session.persona else None,
        persona_icon_color=chat_session.persona.icon_color
        if chat_session.persona
        else None,
        persona_icon_shape=chat_session.persona.icon_shape
        if chat_session.persona
        else None,
        current_alternate_model=chat_session.current_alternate_model,
        messages=[
            translate_db_message_to_chat_message_detail(msg) for msg in session_messages
        ],
        time_created=chat_session.time_created,
        shared_status=chat_session.shared_status,
        current_temperature_override=chat_session.temperature_override,
    )


@router.post("/create-chat-session")
def create_new_chat_session(
    chat_session_creation_request: ChatSessionCreationRequest,
    user: User | None = Depends(current_chat_accesssible_user),
    db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
    user_id = user.id if user is not None else None
    try:
        new_chat_session = create_chat_session(
            db_session=db_session,
            description=chat_session_creation_request.description
            or "",  # Leave the naming till later to prevent delay
            user_id=user_id,
            persona_id=chat_session_creation_request.persona_id,
        )
    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=400, detail="Invalid Persona provided.")

    return CreateChatSessionID(chat_session_id=new_chat_session.id)
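

# Illustrative sketch (not part of the API surface): a typical client first creates a
# session with this endpoint and then streams messages against it via `/chat/send-message`.
# The host/port and auth handling below are assumptions for illustration only; the
# request/response field names mirror ChatSessionCreationRequest / CreateChatSessionID.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8080/chat/create-chat-session",
#       json={"persona_id": 0, "description": None},
#   )
#   chat_session_id = resp.json()["chat_session_id"]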


@router.put("/rename-chat-session")
def rename_chat_session(
    rename_req: ChatRenameRequest,
    request: Request,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RenameChatSessionResponse:
    name = rename_req.name
    chat_session_id = rename_req.chat_session_id
    user_id = user.id if user is not None else None

    if name:
        update_chat_session(
            db_session=db_session,
            user_id=user_id,
            chat_session_id=chat_session_id,
            description=name,
        )
        return RenameChatSessionResponse(new_name=name)

    final_msg, history_msgs = create_chat_chain(
        chat_session_id=chat_session_id, db_session=db_session
    )
    full_history = history_msgs + [final_msg]

    try:
        llm, _ = get_default_llms(
            additional_headers=extract_headers(
                request.headers, LITELLM_PASS_THROUGH_HEADERS
            )
        )
    except GenAIDisabledException:
        # This may be longer than what the LLM tends to produce but is the most
        # clear thing we can do
        return RenameChatSessionResponse(new_name=full_history[0].message)

    new_name = get_renamed_conversation_name(full_history=full_history, llm=llm)

    update_chat_session(
        db_session=db_session,
        user_id=user_id,
        chat_session_id=chat_session_id,
        description=new_name,
    )

    return RenameChatSessionResponse(new_name=new_name)
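

# Note on the endpoint above: when `name` is provided the session is renamed verbatim;
# when it is omitted, the full message history is summarized into a name by the default
# LLM, falling back to the first message in the history if generative AI is disabled.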


@router.patch("/chat-session/{session_id}")
def patch_chat_session(
    session_id: UUID,
    chat_session_update_req: ChatSessionUpdateRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    user_id = user.id if user is not None else None
    update_chat_session(
        db_session=db_session,
        user_id=user_id,
        chat_session_id=session_id,
        sharing_status=chat_session_update_req.sharing_status,
    )
    return None


@router.delete("/delete-all-chat-sessions")
def delete_all_chat_sessions(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
        delete_all_chat_sessions_for_user(user=user, db_session=db_session)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
    session_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    user_id = user.id if user is not None else None
    try:
        delete_chat_session(user_id, session_id, db_session)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))


async def is_connected(request: Request) -> Callable[[], bool]:
    main_loop = asyncio.get_event_loop()

    def is_connected_sync() -> bool:
        future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
        try:
            is_connected = not future.result(timeout=0.01)
            return is_connected
        except asyncio.TimeoutError:
            logger.error("Asyncio timed out")
            return True
        except Exception as e:
            error_msg = str(e)
            logger.critical(
                f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
            )
            return True

    return is_connected_sync
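

# Descriptive note (not from the original source): the streaming endpoint below hands
# Starlette a synchronous generator, which is iterated in a worker thread. The
# `is_connected` dependency therefore captures the main event loop and schedules
# `request.is_disconnected()` back onto it via `run_coroutine_threadsafe`, so the
# resulting `is_connected_func` can be polled from that worker thread between packets.
# An inconclusive check (timeout) is treated as "still connected" to avoid cutting off
# a healthy stream.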


@router.post("/send-message")
def handle_new_chat_message(
    chat_message_req: CreateChatMessageRequest,
    request: Request,
    user: User | None = Depends(current_chat_accesssible_user),
    _rate_limit_check: None = Depends(check_token_rate_limits),
    is_connected_func: Callable[[], bool] = Depends(is_connected),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StreamingResponse:
    """
    This endpoint is used for all of the following purposes:
    - Sending a new message in the session
    - Regenerating a message in the session (just send the same one again)
    - Editing a message (similar to regenerating but sending a different message)
    - Kicking off a seeded chat session (set `use_existing_user_message`)

    Assumes that previous messages have been set as the latest to minimize overhead.

    Args:
        chat_message_req (CreateChatMessageRequest): Details about the new chat message.
        request (Request): The current HTTP request context.
        user (User | None): The current user, obtained via dependency injection.
        _rate_limit_check (None): Rate limit check is run if user/group/global rate
            limits are enabled.
        is_connected_func (Callable[[], bool]): Function to check client disconnection,
            used to stop the streaming response if the client disconnects.
        tenant_id (str): The tenant ID for this request, obtained via dependency
            injection.

    Returns:
        StreamingResponse: Streams the response to the new chat message.
    """
    logger.debug(f"Received new chat message: {chat_message_req.message}")

    if (
        not chat_message_req.message
        and chat_message_req.prompt_id is not None
        and not chat_message_req.use_existing_user_message
    ):
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    with get_session_with_tenant(tenant_id) as db_session:
        create_milestone_and_report(
            user=user,
            distinct_id=user.email if user else tenant_id or "N/A",
            event_type=MilestoneRecordType.RAN_QUERY,
            properties=None,
            db_session=db_session,
        )

    def stream_generator() -> Generator[str, None, None]:
        try:
            for packet in stream_chat_message(
                new_msg_req=chat_message_req,
                user=user,
                litellm_additional_headers=extract_headers(
                    request.headers, LITELLM_PASS_THROUGH_HEADERS
                ),
                custom_tool_additional_headers=get_custom_tool_additional_request_headers(
                    request.headers
                ),
                is_connected=is_connected_func,
            ):
                yield json.dumps(packet) if isinstance(packet, dict) else packet

        except Exception as e:
            logger.exception("Error in chat message streaming")
            yield json.dumps({"error": str(e)})

        finally:
            logger.debug("Stream generator finished")

    return StreamingResponse(stream_generator(), media_type="text/event-stream")
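

# Illustrative sketch (assumptions: a local deployment at localhost:8080 and a session
# created beforehand). Shows how a client might consume the streamed response. The
# payload field names are illustrative; CreateChatMessageRequest is the authoritative
# schema, and packet framing is defined by `stream_chat_message`, not this file.
#
#   import requests
#
#   payload = {
#       "chat_session_id": "<uuid from /chat/create-chat-session>",
#       "parent_message_id": None,
#       "message": "What connectors are configured?",
#       "file_descriptors": [],
#       "prompt_id": None,
#       "search_doc_ids": None,
#   }
#   with requests.post(
#       "http://localhost:8080/chat/send-message", json=payload, stream=True
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="")  # each chunk carries one or more streamed packets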


@router.put("/set-message-as-latest")
def set_message_as_latest(
    message_identifier: ChatMessageIdentifier,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    user_id = user.id if user is not None else None

    chat_message = get_chat_message(
        chat_message_id=message_identifier.message_id,
        user_id=user_id,
        db_session=db_session,
    )

    set_as_latest_chat_message(
        chat_message=chat_message,
        user_id=user_id,
        db_session=db_session,
    )


@router.post("/create-chat-message-feedback")
def create_chat_feedback(
    feedback: ChatFeedbackRequest,
    user: User | None = Depends(current_chat_accesssible_user),
    db_session: Session = Depends(get_session),
) -> None:
    user_id = user.id if user else None

    create_chat_message_feedback(
        is_positive=feedback.is_positive,
        feedback_text=feedback.feedback_text,
        predefined_feedback=feedback.predefined_feedback,
        chat_message_id=feedback.chat_message_id,
        user_id=user_id,
        db_session=db_session,
    )


@router.post("/document-search-feedback")
def create_search_feedback(
    feedback: SearchFeedbackRequest,
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """This endpoint isn't protected - it does not check if the user has access to the document.
    Users could try changing boosts of arbitrary docs, but this does not leak any data.
    """
    create_doc_retrieval_feedback(
        message_id=feedback.message_id,
        document_id=feedback.document_id,
        document_rank=feedback.document_rank,
        clicked=feedback.click,
        feedback=feedback.search_feedback,
        db_session=db_session,
    )


class MaxSelectedDocumentTokens(BaseModel):
    max_tokens: int


@router.get("/max-selected-document-tokens")
def get_max_document_tokens(
    persona_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> MaxSelectedDocumentTokens:
    try:
        persona = get_persona_by_id(
            persona_id=persona_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Persona not found")

    return MaxSelectedDocumentTokens(
        max_tokens=compute_max_document_tokens_for_persona(
            db_session=db_session,
            persona=persona,
        ),
    )


"""Endpoints for chat seeding"""


class ChatSeedRequest(BaseModel):
    # standard chat session stuff
    persona_id: int
    prompt_id: int | None = None

    # overrides / seeding
    llm_override: LLMOverride | None = None
    prompt_override: PromptOverride | None = None
    description: str | None = None
    message: str | None = None

    # TODO: support this
    # initial_message_retrieval_options: RetrievalDetails | None = None


class ChatSeedResponse(BaseModel):
    redirect_url: str


@router.post("/seed-chat-session")
def seed_chat(
    chat_seed_request: ChatSeedRequest,
    # NOTE: realistically, this will be an API key not an actual user
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSeedResponse:
    try:
        new_chat_session = create_chat_session(
            db_session=db_session,
            description=chat_seed_request.description or "",
            user_id=None,  # this chat session is "unassigned" until a user visits the web UI
            persona_id=chat_seed_request.persona_id,
            llm_override=chat_seed_request.llm_override,
            prompt_override=chat_seed_request.prompt_override,
        )
    except Exception as e:
        logger.exception(e)
        raise HTTPException(status_code=400, detail="Invalid Persona provided.")

    if chat_seed_request.message is not None:
        root_message = get_or_create_root_message(
            chat_session_id=new_chat_session.id, db_session=db_session
        )
        llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona)

        tokenizer = get_tokenizer(
            model_name=llm.config.model_name,
            provider_type=llm.config.model_provider,
        )
        token_count = len(tokenizer.encode(chat_seed_request.message))

        create_new_chat_message(
            chat_session_id=new_chat_session.id,
            parent_message=root_message,
            prompt_id=chat_seed_request.prompt_id
            or (
                new_chat_session.persona.prompts[0].id
                if new_chat_session.persona.prompts
                else None
            ),
            message=chat_seed_request.message,
            token_count=token_count,
            message_type=MessageType.USER,
            db_session=db_session,
        )

    return ChatSeedResponse(
        redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}&seeded=true"
    )
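

# Illustrative sketch (hypothetical values; per the NOTE above this is typically called
# server-to-server with an API key rather than a user session, and the auth header shown
# is an assumption). Seeds an unassigned session and hands the returned redirect_url to
# an end user, who claims the session on first visit (see `get_chat_session` above).
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8080/chat/seed-chat-session",
#       json={"persona_id": 0, "message": "Summarize the onboarding docs"},
#       headers={"Authorization": "Bearer <api-key>"},
#   )
#   print(resp.json()["redirect_url"])  # WEB_DOMAIN/chat?chatId=<uuid>&seeded=true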


class SeedChatFromSlackRequest(BaseModel):
    chat_session_id: UUID


class SeedChatFromSlackResponse(BaseModel):
    redirect_url: str


@router.post("/seed-chat-session-from-slack")
def seed_chat_from_slack(
    chat_seed_request: SeedChatFromSlackRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SeedChatFromSlackResponse:
    slack_chat_session_id = chat_seed_request.chat_session_id
    new_chat_session = duplicate_chat_session_for_user_from_slack(
        db_session=db_session,
        user=user,
        chat_session_id=slack_chat_session_id,
    )

    add_chats_to_session_from_slack_thread(
        db_session=db_session,
        slack_chat_session_id=slack_chat_session_id,
        new_chat_session_id=new_chat_session.id,
    )

    return SeedChatFromSlackResponse(
        redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
    )


"""File upload"""


@router.post("/file")
def upload_files_for_chat(
    files: list[UploadFile],
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> dict[str, list[FileDescriptor]]:
    image_content_types = {"image/jpeg", "image/png", "image/webp"}
    csv_content_types = {"text/csv"}
    text_content_types = {
        "text/plain",
        "text/markdown",
        "text/x-markdown",
        "text/x-config",
        "text/tab-separated-values",
        "application/json",
        "application/xml",
        "text/xml",
        "application/x-yaml",
    }
    document_content_types = {
        "application/pdf",
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "message/rfc822",
        "application/epub+zip",
    }

    allowed_content_types = (
        image_content_types.union(text_content_types)
        .union(document_content_types)
        .union(csv_content_types)
    )

    for file in files:
        if not file.content_type:
            raise HTTPException(status_code=400, detail="File content type is required")

        if file.content_type not in allowed_content_types:
            if file.content_type in image_content_types:
                error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
            elif file.content_type in text_content_types:
                error_detail = (
                    "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
                    ".log, .tsv."
                )
            elif file.content_type in csv_content_types:
                error_detail = (
                    "Unsupported CSV file type. Supported CSV types include .csv."
                )
            else:
                error_detail = (
                    "Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
                    ".json, .xml, .yml, .yaml, .eml, .epub."
                )
            raise HTTPException(status_code=400, detail=error_detail)

        if (
            file.content_type in image_content_types
            and file.size
            and file.size > 20 * 1024 * 1024
        ):
            raise HTTPException(
                status_code=400,
                detail="File size must be less than 20MB",
            )

    file_store = get_default_file_store(db_session)

    file_info: list[tuple[str, str | None, ChatFileType]] = []
    for file in files:
        file_type = (
            ChatFileType.IMAGE
            if file.content_type in image_content_types
            else ChatFileType.CSV
            if file.content_type in csv_content_types
            else ChatFileType.DOC
            if file.content_type in document_content_types
            else ChatFileType.PLAIN_TEXT
        )

        file_content = file.file.read()  # Read the file content

        if file_type == ChatFileType.IMAGE:
            file_content_io = file.file
            # NOTE: Image conversion to JPEG used to be enforced here.
            # This was removed to:
            # 1. Preserve original file content for downloads
            # 2. Maintain transparency in formats like PNG
            # 3. Ameliorate issue with file conversion
        else:
            file_content_io = io.BytesIO(file_content)

        new_content_type = file.content_type

        # Store the file normally
        file_id = str(uuid.uuid4())
        file_store.save_file(
            file_name=file_id,
            content=file_content_io,
            display_name=file.filename,
            file_origin=FileOrigin.CHAT_UPLOAD,
            file_type=new_content_type or file_type.value,
        )

        # if the file is a doc, extract text and store that so we don't need
        # to re-extract it every time we send a message
        if file_type == ChatFileType.DOC:
            extracted_text = extract_file_text(
                file=io.BytesIO(file_content),  # use the bytes we already read
                file_name=file.filename or "",
            )
            text_file_id = str(uuid.uuid4())
            file_store.save_file(
                file_name=text_file_id,
                content=io.BytesIO(extracted_text.encode()),
                display_name=file.filename,
                file_origin=FileOrigin.CHAT_UPLOAD,
                file_type="text/plain",
            )
            # for DOC type, just return this for the FileDescriptor
            # as we would always use this as the ID to attach to the
            # message
            file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
        else:
            file_info.append((file_id, file.filename, file_type))

    return {
        "files": [
            {"id": file_id, "type": file_type, "name": file_name}
            for file_id, file_name, file_type in file_info
        ]
    }
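

# Illustrative sketch (assumes a local deployment; the multipart field name "files"
# matches the endpoint's `files: list[UploadFile]` parameter). Uploads a document and
# keeps the returned descriptors, which can then be attached to a /chat/send-message
# request as file descriptors.
#
#   import requests
#
#   with open("report.pdf", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8080/chat/file",
#           files=[("files", ("report.pdf", f, "application/pdf"))],
#       )
#   file_descriptors = resp.json()["files"]  # [{"id": ..., "type": ..., "name": ...}]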


@router.get("/file/{file_id:path}")
def fetch_chat_file(
    file_id: str,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> Response:
    file_store = get_default_file_store(db_session)
    file_record = file_store.read_file_record(file_id)
    if not file_record:
        raise HTTPException(status_code=404, detail="File not found")

    original_file_name = file_record.display_name
    if file_record.file_type.startswith(
        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    ):
        # Check if a converted text file exists for .docx files
        txt_file_name = docx_to_txt_filename(original_file_name)
        txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
        txt_file_record = file_store.read_file_record(txt_file_id)
        if txt_file_record:
            file_record = txt_file_record
            file_id = txt_file_id

    media_type = file_record.file_type
    file_io = file_store.read_file(file_id, mode="b")

    return StreamingResponse(file_io, media_type=media_type)