Port File Store from Volume to PG (#1241)

Yuhong Sun 2024-03-21 20:10:08 -07:00 committed by GitHub
parent 8dbe5cbaa6
commit c28a95e367
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 315 additions and 59 deletions

View File

@ -0,0 +1,28 @@
"""PG File Store
Revision ID: 4738e4b3bae1
Revises: e91df4e935ef
Create Date: 2024-03-20 18:53:32.461518
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4738e4b3bae1"
down_revision = "e91df4e935ef"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"file_store",
sa.Column("file_name", sa.String(), nullable=False),
sa.Column("lobj_oid", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("file_name"),
)
def downgrade() -> None:
op.drop_table("file_store")
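For reference, a minimal sketch (not part of the commit) of applying this revision programmatically; the usual path is simply running "alembic upgrade head" from the backend directory, and the alembic.ini location below is an assumption:

from alembic import command
from alembic.config import Config

# Hypothetical: apply migrations up through this revision (or just "head")
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "4738e4b3bae1")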

View File

@ -1,7 +1,9 @@
"""Private Personas DocumentSets
Revision ID: e91df4e935ef
Revises: 91fd3b470d1a
Create Date: 2024-03-17 11:47:24.675881
"""
import fastapi_users_db_sqlalchemy
from alembic import op

View File

@ -226,8 +226,4 @@ celery_app.conf.beat_schedule = {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"clean-old-temp-files": {
"task": "clean_old_temp_files_task",
"schedule": timedelta(minutes=30),
},
}

View File

@ -2,8 +2,7 @@ import json
import os
import re
import zipfile
from collections.abc import Generator
from pathlib import Path
from collections.abc import Iterator
from typing import Any
from typing import IO
@ -78,11 +77,11 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
# to the zip file. This file should contain a list of objects with the following format:
# [{ "filename": "file1.txt", "link": "https://example.com/file1.txt" }]
def load_files_from_zip(
zip_location: str | Path,
zip_file_io: IO,
ignore_macos_resource_fork_files: bool = True,
ignore_dirs: bool = True,
) -> Generator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]], None, None]:
with zipfile.ZipFile(zip_location, "r") as zip_file:
) -> Iterator[tuple[zipfile.ZipInfo, IO[Any], dict[str, Any]]]:
with zipfile.ZipFile(zip_file_io, "r") as zip_file:
zip_metadata = {}
try:
metadata_file_info = zip_file.getinfo(".danswer_metadata.json")
@ -109,18 +108,19 @@ def load_files_from_zip(
yield file_info, file, zip_metadata.get(file_info.filename, {})
def detect_encoding(file_path: str | Path) -> str:
with open(file_path, "rb") as file:
raw_data = file.read(50000) # Read a portion of the file to guess encoding
return chardet.detect(raw_data)["encoding"] or "utf-8"
def detect_encoding(file: IO[bytes]) -> str:
raw_data = file.read(50000)
encoding = chardet.detect(raw_data)["encoding"] or "utf-8"
file.seek(0)
return encoding
def read_file(
file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace"
file: IO, encoding: str = "utf-8", errors: str = "replace"
) -> tuple[str, dict]:
metadata = {}
file_content_raw = ""
for ind, line in enumerate(file_reader):
for ind, line in enumerate(file):
try:
line = line.decode(encoding) if isinstance(line, bytes) else line
except UnicodeDecodeError:
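As a quick illustration (not part of the commit): the reworked helpers now operate on in-memory file objects instead of filesystem paths. A minimal sketch, assuming the import path used by the file connector below; the sample bytes are made up:

from io import BytesIO

from danswer.connectors.cross_connector_utils.file_utils import (
    detect_encoding,
    read_file,
)

# detect_encoding samples the stream and seeks back to 0, so the same
# IO object can be handed straight to read_file afterwards
file_io = BytesIO("héllo wörld\n".encode("utf-8"))
encoding = detect_encoding(file_io)
text, metadata = read_file(file_io, encoding=encoding)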

View File

@ -1,11 +1,13 @@
import os
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import IO
from sqlalchemy.orm import Session
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.file_utils import detect_encoding
@ -20,37 +22,40 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _open_files_at_location(
file_path: str | Path,
) -> Generator[tuple[str, IO[Any], dict[str, Any]], Any, None]:
extension = get_file_ext(file_path)
def _read_files_and_metadata(
file_name: str,
db_session: Session,
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
"""Reads the file into IO, in the case of a zip file, yields each individual
file contained within, also includes the metadata dict if packaged in the zip"""
extension = get_file_ext(file_name)
metadata: dict[str, Any] = {}
directory_path = os.path.dirname(file_name)
file_content = get_default_file_store(db_session).read_file(file_name, mode="b")
if extension == ".zip":
for file_info, file, metadata in load_files_from_zip(
file_path, ignore_dirs=True
file_content, ignore_dirs=True
):
yield file_info.filename, file, metadata
elif extension in [".txt", ".md", ".mdx"]:
encoding = detect_encoding(file_path)
with open(file_path, "r", encoding=encoding, errors="replace") as file:
yield os.path.basename(file_path), file, metadata
elif extension == ".pdf":
with open(file_path, "rb") as file:
yield os.path.basename(file_path), file, metadata
yield os.path.join(directory_path, file_info.filename), file, metadata
elif extension in [".txt", ".md", ".mdx", ".pdf"]:
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_path}' with extension '{extension}'")
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
def _process_file(
file_name: str,
file: IO[Any],
metadata: dict[str, Any] = {},
metadata: dict[str, Any] | None = None,
pdf_pass: str | None = None,
) -> list[Document]:
extension = get_file_ext(file_name)
@ -65,8 +70,9 @@ def _process_file(
file=file, file_name=file_name, pdf_pass=pdf_pass
)
else:
file_content_raw, file_metadata = read_file(file)
all_metadata = {**metadata, **file_metadata}
encoding = detect_encoding(file)
file_content_raw, file_metadata = read_file(file, encoding=encoding)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# If this is set, we will show this in the UI as the "name" of the file
file_display_name_override = all_metadata.get("file_display_name")
@ -114,7 +120,8 @@ def _process_file(
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
semantic_identifier=file_display_name_override or file_name,
semantic_identifier=file_display_name_override
or os.path.basename(file_name),
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
@ -140,24 +147,27 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
for file_location in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _open_files_at_location(file_location)
for file_name, file, metadata in files:
metadata["time_updated"] = metadata.get(
"time_updated", current_datetime
)
documents.extend(
_process_file(file_name, file, metadata, self.pdf_pass)
with Session(get_sqlalchemy_engine()) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
file_name=str(file_path), db_session=db_session
)
if len(documents) >= self.batch_size:
yield documents
documents = []
for file_name, file, metadata in files:
metadata["time_updated"] = metadata.get(
"time_updated", current_datetime
)
documents.extend(
_process_file(file_name, file, metadata, self.pdf_pass)
)
if documents:
yield documents
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
yield documents
if __name__ == "__main__":

View File

@ -388,7 +388,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorugh
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.

View File

@ -5,6 +5,7 @@ from typing import cast
from bs4 import BeautifulSoup
from bs4 import Tag
from sqlalchemy.orm import Session
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@ -15,6 +16,8 @@ from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -66,8 +69,13 @@ class GoogleSitesConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with Session(get_sqlalchemy_engine()) as db_session:
file_content_io = get_default_file_store(db_session).read_file(
self.zip_path, mode="b"
)
# load the HTML files
files = load_files_from_zip(self.zip_path)
files = load_files_from_zip(file_content_io)
count = 0
for file_info, file_io, _metadata in files:
# skip non-published files

View File

@ -149,6 +149,10 @@ class WebConnector(LoadConnector):
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
logger.warning(
"This is not a UI supported Web Connector flow, "
"are you sure you want to do this?"
)
self.to_visit_list = _read_urls_file(base_url)
else:

View File

@ -0,0 +1,96 @@
from abc import ABC
from abc import abstractmethod
from typing import IO
from sqlalchemy.orm import Session
from danswer.db.pg_file_store import create_populate_lobj
from danswer.db.pg_file_store import delete_lobj_by_id
from danswer.db.pg_file_store import delete_pgfilestore_by_file_name
from danswer.db.pg_file_store import get_pgfilestore_by_file_name
from danswer.db.pg_file_store import read_lobj
from danswer.db.pg_file_store import upsert_pgfilestore
class FileStore(ABC):
"""
An abstraction for storing files and large binary objects.
"""
@abstractmethod
def save_file(self, file_name: str, content: IO) -> None:
"""
Save a file to the blob store
Parameters:
- file_name: Name of the file to save
- content: Contents of the file
"""
raise NotImplementedError
@abstractmethod
def read_file(self, file_name: str, mode: str | None) -> IO:
"""
Read the content of a given file by its name
Parameters:
- file_name: Name of file to read
- mode: Optional read mode (e.g. "b" for binary)
Returns:
Contents of the file as an IO object
"""
@abstractmethod
def delete_file(self, file_name: str) -> None:
"""
Delete a file by its name.
Parameters:
- file_name: Name of file to delete
"""
class PostgresBackedFileStore(FileStore):
def __init__(self, db_session: Session):
self.db_session = db_session
def save_file(self, file_name: str, content: IO) -> None:
try:
# The large objects in postgres are saved as special objects that can be listed with
# SELECT * FROM pg_largeobject_metadata;
obj_id = create_populate_lobj(content=content, db_session=self.db_session)
upsert_pgfilestore(
file_name=file_name, lobj_oid=obj_id, db_session=self.db_session
)
self.db_session.commit()
except Exception:
self.db_session.rollback()
raise
def read_file(self, file_name: str, mode: str | None = None) -> IO:
file_record = get_pgfilestore_by_file_name(
file_name=file_name, db_session=self.db_session
)
return read_lobj(
lobj_oid=file_record.lobj_oid, db_session=self.db_session, mode=mode
)
def delete_file(self, file_name: str) -> None:
try:
file_record = get_pgfilestore_by_file_name(
file_name=file_name, db_session=self.db_session
)
delete_lobj_by_id(file_record.lobj_oid, db_session=self.db_session)
delete_pgfilestore_by_file_name(
file_name=file_name, db_session=self.db_session
)
self.db_session.commit()
except Exception:
self.db_session.rollback()
raise
def get_default_file_store(db_session: Session) -> FileStore:
# The only supported file store now is the Postgres File Store
return PostgresBackedFileStore(db_session=db_session)
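A minimal usage sketch of the new abstraction (illustrative only; the file name below is hypothetical):

from io import BytesIO

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.file_store import get_default_file_store

with Session(get_sqlalchemy_engine()) as db_session:
    file_store = get_default_file_store(db_session)

    # save, read back, and delete are all keyed off the logical file name
    file_store.save_file(file_name="example/notes.txt", content=BytesIO(b"hello"))
    contents = file_store.read_file("example/notes.txt", mode="b").read()
    file_store.delete_file("example/notes.txt")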

View File

@ -859,3 +859,9 @@ class KVStore(Base):
key: Mapped[str] = mapped_column(String, primary_key=True)
value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=False)
class PGFileStore(Base):
__tablename__ = "file_store"
file_name = mapped_column(String, primary_key=True)
lobj_oid = mapped_column(Integer, nullable=False)

View File

@ -0,0 +1,93 @@
from io import BytesIO
from typing import IO
from psycopg2.extensions import connection
from sqlalchemy.orm import Session
from danswer.db.models import PGFileStore
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_pg_conn_from_session(db_session: Session) -> connection:
return db_session.connection().connection.connection # type: ignore
def create_populate_lobj(
content: IO,
db_session: Session,
) -> int:
"""Note, this does not commit the changes to the DB
This is because the commit should happen with the PGFileStore row creation
That step finalizes both the Large Object and the table tracking it
"""
pg_conn = get_pg_conn_from_session(db_session)
large_object = pg_conn.lobject()
large_object.write(content.read())
large_object.close()
return large_object.oid
def read_lobj(lobj_oid: int, db_session: Session, mode: str | None = None) -> IO:
pg_conn = get_pg_conn_from_session(db_session)
large_object = (
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
)
return BytesIO(large_object.read())
def delete_lobj_by_id(
lobj_oid: int,
db_session: Session,
) -> None:
pg_conn = get_pg_conn_from_session(db_session)
pg_conn.lobject(lobj_oid).unlink()
def upsert_pgfilestore(
file_name: str, lobj_oid: int, db_session: Session, commit: bool = False
) -> PGFileStore:
pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
if pgfilestore:
try:
# This should not happen in normal execution
delete_lobj_by_id(lobj_oid=pgfilestore.lobj_oid, db_session=db_session)
except Exception:
# If the delete fails as well, the large object most likely doesn't exist anyway;
# even if it does, leaking it is not too terrible since most file sizes are insignificant
logger.error(
f"Failed to delete large object with oid {pgfilestore.lobj_oid}"
)
pgfilestore.lobj_oid = lobj_oid
else:
pgfilestore = PGFileStore(file_name=file_name, lobj_oid=lobj_oid)
db_session.add(pgfilestore)
if commit:
db_session.commit()
return pgfilestore
def get_pgfilestore_by_file_name(
file_name: str,
db_session: Session,
) -> PGFileStore:
pgfilestore = db_session.query(PGFileStore).filter_by(file_name=file_name).first()
if not pgfilestore:
raise RuntimeError(f"File by name {file_name} does not exist or was deleted")
return pgfilestore
def delete_pgfilestore_by_file_name(
file_name: str,
db_session: Session,
) -> None:
db_session.query(PGFileStore).filter_by(file_name=file_name).delete()
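For context, a sketch (not part of the commit) of how these helpers compose, mirroring what PostgresBackedFileStore.save_file does; the file name and payload below are made up, and the single commit finalizes both the large object and its tracking row:

from io import BytesIO

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.pg_file_store import (
    create_populate_lobj,
    read_lobj,
    upsert_pgfilestore,
)

with Session(get_sqlalchemy_engine()) as db_session:
    # write the bytes into a Postgres large object (no commit yet)
    lobj_oid = create_populate_lobj(content=BytesIO(b"payload"), db_session=db_session)
    # track it in the file_store table, then commit both in one step
    upsert_pgfilestore(
        file_name="scratch/demo.bin", lobj_oid=lobj_oid, db_session=db_session
    )
    db_session.commit()
    data = read_lobj(lobj_oid=lobj_oid, db_session=db_session).read()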

View File

@ -1,3 +1,5 @@
import os
import uuid
from typing import cast
from fastapi import APIRouter
@ -13,7 +15,6 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_utils import get_deletion_status
from danswer.configs.constants import DocumentSource
from danswer.connectors.file.utils import write_temp_files
from danswer.connectors.gmail.connector_auth import delete_gmail_service_account_key
from danswer.connectors.gmail.connector_auth import delete_google_app_gmail_cred
from danswer.connectors.gmail.connector_auth import get_gmail_auth_url
@ -57,6 +58,7 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.file_store import get_default_file_store
from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import create_index_attempt
@ -335,18 +337,23 @@ def admin_google_drive_auth(
@router.post("/admin/connector/file/upload")
def upload_files(
files: list[UploadFile], _: User = Depends(current_admin_user)
files: list[UploadFile],
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
try:
file_paths = write_temp_files(
[(cast(str, file.filename), file.file) for file in files]
)
file_store = get_default_file_store(db_session)
deduped_file_paths = []
for file in files:
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
deduped_file_paths.append(file_path)
file_store.save_file(file_name=file_path, content=file.file)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return FileUploadResponse(file_paths=file_paths)
return FileUploadResponse(file_paths=deduped_file_paths)
@router.get("/admin/connector/indexing-status")
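To illustrate the de-duplication scheme above (illustrative, not part of the commit): each upload is stored under a "<random uuid>/<original filename>" key, so identically named uploads never collide in the file store.

import os
import uuid

# Hypothetical key produced for an uploaded "notes.txt"
file_path = os.path.join(str(uuid.uuid4()), "notes.txt")
# e.g. "5f3c9b1a-.../notes.txt"; it can later be read back with
#   get_default_file_store(db_session).read_file(file_path, mode="b")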

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import GEN_AI_DETECTED_MODEL
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@ -21,6 +22,7 @@ from danswer.db.engine import get_session
from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.feedback import update_document_boost
from danswer.db.feedback import update_document_hidden
from danswer.db.file_store import get_default_file_store
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
@ -254,3 +256,9 @@ def create_deletion_attempt_for_connector_id(
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(connector_id=connector_id, credential_id=credential_id),
)
if cc_pair.connector.source == DocumentSource.FILE:
connector = cc_pair.connector
file_store = get_default_file_store(db_session)
for file_name in connector.connector_specific_config["file_locations"]:
file_store.delete_file(file_name)

View File

@ -3,7 +3,6 @@ import importlib
from typing import Any
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
logger = setup_logger()
@ -23,7 +22,6 @@ class DanswerVersion:
global_version = DanswerVersion()
@log_function_time(print_only=True, include_args=True)
@functools.lru_cache(maxsize=128)
def fetch_versioned_implementation(module: str, attribute: str) -> Any:
logger.info("Fetching versioned implementation for %s.%s", module, attribute)