import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import IO

from sqlalchemy.orm import Session

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.db.engine import get_session_with_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()
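
# Overview: this connector indexes files previously saved into the
# application's default file store (see get_default_file_store). Zip archives
# are expanded so each contained file becomes its own Document, and a metadata
# dict bundled inside the zip is applied to the files it describes.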


def _read_files_and_metadata(
    file_name: str,
    db_session: Session,
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
    """Reads the file into an IO stream. In the case of a zip file, yields each
    individual file contained within, along with the metadata dict if one is
    packaged in the zip."""
    extension = get_file_ext(file_name)
    metadata: dict[str, Any] = {}
    directory_path = os.path.dirname(file_name)

    file_content = get_default_file_store(db_session).read_file(file_name, mode="b")

    if extension == ".zip":
        for file_info, file, metadata in load_files_from_zip(
            file_content, ignore_dirs=True
        ):
            yield os.path.join(directory_path, file_info.filename), file, metadata
    elif is_valid_file_ext(extension):
        yield file_name, file_content, metadata
    else:
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
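
# Behavior sketch (hypothetical paths, not actual stored files): for an
# archive "uploads/docs.zip" containing "a.txt" and "b.pdf", this yields
# ("uploads/a.txt", <IO>, metadata) and ("uploads/b.pdf", <IO>, metadata);
# for a single supported file, it yields that file with an empty metadata dict.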


def _process_file(
    file_name: str,
    file: IO[Any],
    metadata: dict[str, Any] | None = None,
    pdf_pass: str | None = None,
) -> list[Document]:
    extension = get_file_ext(file_name)
    if not is_valid_file_ext(extension):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return []

    file_metadata: dict[str, Any] = {}

    if is_text_file_extension(file_name):
        encoding = detect_encoding(file)
        file_content_raw, file_metadata = read_text_file(
            file, encoding=encoding, ignore_onyx_metadata=False
        )

    # Using the PDF reader function directly to pass in password cleanly
    elif extension == ".pdf" and pdf_pass is not None:
        file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)

    else:
        file_content_raw = extract_file_text(
            file=file,
            file_name=file_name,
            break_on_unprocessable=True,
        )
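
    # Note on the three extraction paths above: plain-text files may carry
    # inline Onyx metadata (hence ignore_onyx_metadata=False), PDFs are read
    # directly so the password can be passed through, and all other supported
    # types go through the generic extract_file_text dispatcher.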

    all_metadata = {**metadata, **file_metadata} if metadata else file_metadata

    # add a prefix to avoid conflicts with other connectors
    doc_id = f"FILE_CONNECTOR__{file_name}"
    if metadata:
        doc_id = metadata.get("document_id") or doc_id

    # If this is set, we will show this in the UI as the "name" of the file
    file_display_name = all_metadata.get("file_display_name") or os.path.basename(
        file_name
    )
    # An explicitly provided title wins (an empty one yields an empty title);
    # otherwise fall back to the display name
    title = (
        all_metadata["title"] or "" if "title" in all_metadata else file_display_name
    )

    time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
    if isinstance(time_updated, str):
        time_updated = time_str_to_utc(time_updated)

    # "doc_updated_at" takes precedence over "time_updated" when both are set
    dt_str = all_metadata.get("doc_updated_at")
    final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated

    # Metadata tags separate from the Onyx specific fields
    metadata_tags = {
        k: v
        for k, v in all_metadata.items()
        if k
        not in [
            "document_id",
            "time_updated",
            "doc_updated_at",
            "link",
            "primary_owners",
            "secondary_owners",
            "filename",
            "file_display_name",
            "title",
            "connector_type",
        ]
    }

    source_type_str = all_metadata.get("connector_type")
    source_type = DocumentSource(source_type_str) if source_type_str else None

    p_owner_names = all_metadata.get("primary_owners")
    s_owner_names = all_metadata.get("secondary_owners")
    p_owners = (
        [BasicExpertInfo(display_name=name) for name in p_owner_names]
        if p_owner_names
        else None
    )
    s_owners = (
        [BasicExpertInfo(display_name=name) for name in s_owner_names]
        if s_owner_names
        else None
    )

    return [
        Document(
            id=doc_id,
            sections=[
                Section(link=all_metadata.get("link"), text=file_content_raw.strip())
            ],
            source=source_type or DocumentSource.FILE,
            semantic_identifier=file_display_name,
            title=title,
            doc_updated_at=final_time_updated,
            primary_owners=p_owners,
            secondary_owners=s_owners,
            # currently metadata just houses tags, other stuff like owners / updated at have dedicated fields
            metadata=metadata_tags,
        )
    ]
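
# Metadata contract sketch (hypothetical values): given
# metadata={"document_id": "doc-1", "link": "https://example.com",
#           "primary_owners": ["Ada"], "tag": "finance"},
# _process_file emits a Document with id "doc-1", a single section linking to
# https://example.com, one primary owner, and {"tag": "finance"} as the only
# metadata tag, since the reserved keys above are filtered out.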


class LocalFileConnector(LoadConnector):
    def __init__(
        self,
        file_locations: list[Path | str],
        tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.file_locations = [Path(file_location) for file_location in file_locations]
        self.batch_size = batch_size
        self.tenant_id = tenant_id
        self.pdf_pass: str | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.pdf_pass = credentials.get("pdf_password")
        return None
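
    # Example credentials payload (hypothetical value; "pdf_password" is the
    # only key read here):
    #   connector.load_credentials({"pdf_password": "example-password"})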

    def load_from_state(self) -> GenerateDocumentsOutput:
        documents: list[Document] = []
        # Pin the current tenant for the duration of the load; reset below
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)

        with get_session_with_tenant(self.tenant_id) as db_session:
            for file_path in self.file_locations:
                current_datetime = datetime.now(timezone.utc)
                files = _read_files_and_metadata(
                    file_name=str(file_path), db_session=db_session
                )

                for file_name, file, metadata in files:
                    metadata["time_updated"] = metadata.get(
                        "time_updated", current_datetime
                    )
                    documents.extend(
                        _process_file(file_name, file, metadata, self.pdf_pass)
                    )

                    # Emit a batch once enough documents have accumulated
                    if len(documents) >= self.batch_size:
                        yield documents
                        documents = []

            if documents:
                yield documents

        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


if __name__ == "__main__":
    connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
    connector.load_credentials({"pdf_password": os.environ["PDF_PASSWORD"]})

    document_batches = connector.load_from_state()
    print(next(document_batches))
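
    # A fuller smoke test (a sketch; assumes TEST_FILE names a file already
    # present in the file store):
    # for batch in connector.load_from_state():
    #     for doc in batch:
    #         print(doc.id, doc.semantic_identifier)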