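"""File connector: reads previously uploaded files from the Onyx file store,
extracts their text and metadata, and yields them as batches of Documents."""
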
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import IO

from sqlalchemy.orm import Session

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.db.engine import get_session_with_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


def _read_files_and_metadata(
file_name: str,
db_session: Session,
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
"""Reads the file into IO, in the case of a zip file, yields each individual
file contained within, also includes the metadata dict if packaged in the zip"""
extension = get_file_ext(file_name)
metadata: dict[str, Any] = {}
directory_path = os.path.dirname(file_name)
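    # Contents are read back from the default file store rather than local disk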
file_content = get_default_file_store(db_session).read_file(file_name, mode="b")
if extension == ".zip":
for file_info, file, metadata in load_files_from_zip(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), file, metadata
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
def _process_file(
file_name: str,
file: IO[Any],
metadata: dict[str, Any] | None = None,
pdf_pass: str | None = None,
) -> list[Document]:
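    """Extracts text (and any embedded metadata) from a single file and wraps
    the result in a Document. Returns an empty list for unsupported extensions."""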
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return []
file_metadata: dict[str, Any] = {}
if is_text_file_extension(file_name):
encoding = detect_encoding(file)
file_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
# Using the PDF reader function directly to pass in password cleanly
elif extension == ".pdf" and pdf_pass is not None:
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
else:
file_content_raw = extract_file_text(
file=file,
file_name=file_name,
break_on_unprocessable=True,
)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# add a prefix to avoid conflicts with other connectors
doc_id = f"FILE_CONNECTOR__{file_name}"
if metadata:
doc_id = metadata.get("document_id") or doc_id
# If this is set, we will show this in the UI as the "name" of the file
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
file_name
)
    title = (
        (all_metadata["title"] or "")
        if "title" in all_metadata
        else file_display_name
    )
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
if isinstance(time_updated, str):
time_updated = time_str_to_utc(time_updated)
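    # An explicit "doc_updated_at" takes precedence over "time_updated"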
dt_str = all_metadata.get("doc_updated_at")
final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated
# Metadata tags separate from the Onyx specific fields
metadata_tags = {
k: v
for k, v in all_metadata.items()
if k
not in [
"document_id",
"time_updated",
"doc_updated_at",
"link",
"primary_owners",
"secondary_owners",
"filename",
"file_display_name",
"title",
"connector_type",
]
}
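    # "connector_type" metadata can override the default FILE source type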
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
[BasicExpertInfo(display_name=name) for name in p_owner_names]
if p_owner_names
else None
)
s_owners = (
[BasicExpertInfo(display_name=name) for name in s_owner_names]
if s_owner_names
else None
)
return [
Document(
id=doc_id,
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=source_type or DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
# currently metadata just houses tags, other stuff like owners / updated at have dedicated fields
metadata=metadata_tags,
)
    ]


class LocalFileConnector(LoadConnector):
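    """Connector that converts files saved in the default file store into Documents."""
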
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
        self.pdf_pass: str | None = None

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
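        """Only "pdf_password" is used; it unlocks password-protected PDFs."""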
self.pdf_pass = credentials.get("pdf_password")
        return None

def load_from_state(self) -> GenerateDocumentsOutput:
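        """Reads each configured file from the file store and yields parsed
        Documents in batches of at most `batch_size`."""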
        documents: list[Document] = []
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
        try:
            with get_session_with_tenant(self.tenant_id) as db_session:
                for file_path in self.file_locations:
                    current_datetime = datetime.now(timezone.utc)
                    files = _read_files_and_metadata(
                        file_name=str(file_path), db_session=db_session
                    )
                    for file_name, file, metadata in files:
                        metadata["time_updated"] = metadata.get(
                            "time_updated", current_datetime
                        )
                        documents.extend(
                            _process_file(file_name, file, metadata, self.pdf_pass)
                        )
                        if len(documents) >= self.batch_size:
                            yield documents
                            documents = []

                if documents:
                    yield documents
        finally:
            # Reset the tenant contextvar even if parsing raises partway through
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
connector.load_credentials({"pdf_password": os.environ["PDF_PASSWORD"]})
document_batches = connector.load_from_state()
print(next(document_batches))