mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-06 13:09:39 +02:00
296 lines
9.4 KiB
Python
296 lines
9.4 KiB
Python
import base64
|
|
from collections.abc import Callable
|
|
from io import BytesIO
|
|
from typing import cast
|
|
from uuid import uuid4
|
|
|
|
import requests
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.configs.constants import FileOrigin
|
|
from onyx.db.engine import get_session_with_current_tenant
|
|
from onyx.db.models import ChatMessage
|
|
from onyx.db.models import UserFile
|
|
from onyx.db.models import UserFolder
|
|
from onyx.file_store.file_store import get_default_file_store
|
|
from onyx.file_store.models import ChatFileType
|
|
from onyx.file_store.models import FileDescriptor
|
|
from onyx.file_store.models import InMemoryChatFile
|
|
from onyx.utils.b64 import get_image_type
|
|
from onyx.utils.logger import setup_logger
|
|
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
def user_file_id_to_plaintext_file_name(user_file_id: int) -> str:
|
|
"""Generate a consistent file name for storing plaintext content of a user file."""
|
|
return f"plaintext_{user_file_id}"
|
|
|
|
|
|
def store_user_file_plaintext(
|
|
user_file_id: int, plaintext_content: str, db_session: Session
|
|
) -> bool:
|
|
"""
|
|
Store plaintext content for a user file in the file store.
|
|
|
|
Args:
|
|
user_file_id: The ID of the user file
|
|
plaintext_content: The plaintext content to store
|
|
db_session: The database session
|
|
|
|
Returns:
|
|
bool: True if storage was successful, False otherwise
|
|
"""
|
|
# Skip empty content
|
|
if not plaintext_content:
|
|
return False
|
|
|
|
# Get plaintext file name
|
|
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
|
|
|
# Store the plaintext in the file store
|
|
file_store = get_default_file_store(db_session)
|
|
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
|
try:
|
|
file_store.save_file(
|
|
file_name=plaintext_file_name,
|
|
content=file_content,
|
|
display_name=f"Plaintext for user file {user_file_id}",
|
|
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
|
file_type="text/plain",
|
|
commit=False,
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
|
return False
|
|
|
|
|
|
def load_chat_file(
|
|
file_descriptor: FileDescriptor, db_session: Session
|
|
) -> InMemoryChatFile:
|
|
file_io = get_default_file_store(db_session).read_file(
|
|
file_descriptor["id"], mode="b"
|
|
)
|
|
return InMemoryChatFile(
|
|
file_id=file_descriptor["id"],
|
|
content=file_io.read(),
|
|
file_type=file_descriptor["type"],
|
|
filename=file_descriptor.get("name"),
|
|
)
|
|
|
|
|
|
def load_all_chat_files(
|
|
chat_messages: list[ChatMessage],
|
|
file_descriptors: list[FileDescriptor],
|
|
db_session: Session,
|
|
) -> list[InMemoryChatFile]:
|
|
file_descriptors_for_history: list[FileDescriptor] = []
|
|
for chat_message in chat_messages:
|
|
if chat_message.files:
|
|
file_descriptors_for_history.extend(chat_message.files)
|
|
|
|
files = cast(
|
|
list[InMemoryChatFile],
|
|
run_functions_tuples_in_parallel(
|
|
[
|
|
(load_chat_file, (file, db_session))
|
|
for file in file_descriptors + file_descriptors_for_history
|
|
]
|
|
),
|
|
)
|
|
return files
|
|
|
|
|
|
def load_user_folder(folder_id: int, db_session: Session) -> list[InMemoryChatFile]:
|
|
user_files = (
|
|
db_session.query(UserFile).filter(UserFile.folder_id == folder_id).all()
|
|
)
|
|
return [load_user_file(file.id, db_session) for file in user_files]
|
|
|
|
|
|
def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
|
|
user_file = db_session.query(UserFile).filter(UserFile.id == file_id).first()
|
|
if not user_file:
|
|
raise ValueError(f"User file with id {file_id} not found")
|
|
|
|
# Try to load plaintext version first
|
|
file_store = get_default_file_store(db_session)
|
|
plaintext_file_name = user_file_id_to_plaintext_file_name(file_id)
|
|
|
|
try:
|
|
file_io = file_store.read_file(plaintext_file_name, mode="b")
|
|
return InMemoryChatFile(
|
|
file_id=str(user_file.file_id),
|
|
content=file_io.read(),
|
|
file_type=ChatFileType.USER_KNOWLEDGE,
|
|
filename=user_file.name,
|
|
)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Failed to load plaintext file {plaintext_file_name}, defaulting to original file: {e}"
|
|
)
|
|
# Fall back to original file if plaintext not available
|
|
file_io = file_store.read_file(user_file.file_id, mode="b")
|
|
return InMemoryChatFile(
|
|
file_id=str(user_file.file_id),
|
|
content=file_io.read(),
|
|
file_type=ChatFileType.USER_KNOWLEDGE,
|
|
filename=user_file.name,
|
|
)
|
|
|
|
|
|
def load_all_user_files(
|
|
user_file_ids: list[int],
|
|
user_folder_ids: list[int],
|
|
db_session: Session,
|
|
) -> list[InMemoryChatFile]:
|
|
return cast(
|
|
list[InMemoryChatFile],
|
|
run_functions_tuples_in_parallel(
|
|
[(load_user_file, (file_id, db_session)) for file_id in user_file_ids]
|
|
)
|
|
+ [
|
|
file
|
|
for folder_id in user_folder_ids
|
|
for file in load_user_folder(folder_id, db_session)
|
|
],
|
|
)
|
|
|
|
|
|
def load_all_user_file_files(
|
|
user_file_ids: list[int],
|
|
user_folder_ids: list[int],
|
|
db_session: Session,
|
|
) -> list[UserFile]:
|
|
user_files: list[UserFile] = []
|
|
for user_file_id in user_file_ids:
|
|
user_file = (
|
|
db_session.query(UserFile).filter(UserFile.id == user_file_id).first()
|
|
)
|
|
if user_file is not None:
|
|
user_files.append(user_file)
|
|
for user_folder_id in user_folder_ids:
|
|
user_files.extend(
|
|
db_session.query(UserFile)
|
|
.filter(UserFile.folder_id == user_folder_id)
|
|
.all()
|
|
)
|
|
return user_files
|
|
|
|
|
|
def save_file_from_url(url: str) -> str:
|
|
"""NOTE: using multiple sessions here, since this is often called
|
|
using multithreading. In practice, sharing a session has resulted in
|
|
weird errors."""
|
|
with get_session_with_current_tenant() as db_session:
|
|
response = requests.get(url)
|
|
response.raise_for_status()
|
|
|
|
unique_id = str(uuid4())
|
|
|
|
file_io = BytesIO(response.content)
|
|
file_store = get_default_file_store(db_session)
|
|
file_store.save_file(
|
|
file_name=unique_id,
|
|
content=file_io,
|
|
display_name="GeneratedImage",
|
|
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
|
file_type="image/png;base64",
|
|
commit=True,
|
|
)
|
|
return unique_id
|
|
|
|
|
|
def save_file_from_base64(base64_string: str) -> str:
|
|
with get_session_with_current_tenant() as db_session:
|
|
unique_id = str(uuid4())
|
|
file_store = get_default_file_store(db_session)
|
|
file_store.save_file(
|
|
file_name=unique_id,
|
|
content=BytesIO(base64.b64decode(base64_string)),
|
|
display_name="GeneratedImage",
|
|
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
|
file_type=get_image_type(base64_string),
|
|
commit=True,
|
|
)
|
|
return unique_id
|
|
|
|
|
|
def save_file(
|
|
url: str | None = None,
|
|
base64_data: str | None = None,
|
|
) -> str:
|
|
"""Save a file from either a URL or base64 encoded string.
|
|
|
|
Args:
|
|
url: URL to download file from
|
|
base64_data: Base64 encoded file data
|
|
|
|
Returns:
|
|
The unique ID of the saved file
|
|
|
|
Raises:
|
|
ValueError: If neither url nor base64_data is provided, or if both are provided
|
|
"""
|
|
if url is not None and base64_data is not None:
|
|
raise ValueError("Cannot specify both url and base64_data")
|
|
|
|
if url is not None:
|
|
return save_file_from_url(url)
|
|
elif base64_data is not None:
|
|
return save_file_from_base64(base64_data)
|
|
else:
|
|
raise ValueError("Must specify either url or base64_data")
|
|
|
|
|
|
def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
|
|
# NOTE: be explicit about typing so that if we change things, we get notified
|
|
funcs: list[
|
|
tuple[
|
|
Callable[[str | None, str | None], str],
|
|
tuple[str | None, str | None],
|
|
]
|
|
] = [(save_file, (url, None)) for url in urls] + [
|
|
(save_file, (None, base64_file)) for base64_file in base64_files
|
|
]
|
|
|
|
return run_functions_tuples_in_parallel(funcs)
|
|
|
|
|
|
def load_all_persona_files_for_chat(
|
|
persona_id: int, db_session: Session
|
|
) -> tuple[list[InMemoryChatFile], list[int]]:
|
|
from onyx.db.models import Persona
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
persona = (
|
|
db_session.query(Persona)
|
|
.filter(Persona.id == persona_id)
|
|
.options(
|
|
joinedload(Persona.user_files),
|
|
joinedload(Persona.user_folders).joinedload(UserFolder.files),
|
|
)
|
|
.one()
|
|
)
|
|
|
|
persona_file_calls = [
|
|
(load_user_file, (user_file.id, db_session)) for user_file in persona.user_files
|
|
]
|
|
persona_loaded_files = run_functions_tuples_in_parallel(persona_file_calls)
|
|
|
|
persona_folder_files = []
|
|
persona_folder_file_ids = []
|
|
for user_folder in persona.user_folders:
|
|
folder_files = load_user_folder(user_folder.id, db_session)
|
|
persona_folder_files.extend(folder_files)
|
|
persona_folder_file_ids.extend([file.id for file in user_folder.files])
|
|
|
|
persona_files = list(persona_loaded_files) + persona_folder_files
|
|
persona_file_ids = [
|
|
file.id for file in persona.user_files
|
|
] + persona_folder_file_ids
|
|
|
|
return persona_files, persona_file_ids
|