import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from typing import Any
from typing import Optional

import boto3  # type: ignore
from botocore.client import Config  # type: ignore
from mypy_boto3_s3 import S3Client  # type: ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger

logger = setup_logger()


class BlobStorageConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        bucket_type: str,
        bucket_name: str,
        prefix: str = "",
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.bucket_type: BlobType = BlobType(bucket_type)
        self.bucket_name = bucket_name
        self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
        self.batch_size = batch_size
        self.s3_client: Optional[S3Client] = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Checks for boto3 credentials based on the bucket type.
        (1) R2: Access Key ID, Secret Access Key, Account ID
        (2) S3: AWS Access Key ID, AWS Secret Access Key
        (3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID
        (4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key

        For each bucket type, the method initializes the appropriate S3 client:
        - R2: Uses Cloudflare R2 endpoint with S3v4 signature
        - S3: Creates a standard boto3 S3 client
        - GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint
        - OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint

        Raises ConnectorMissingCredentialError if required credentials are missing.
        Raises ValueError for unsupported bucket types.
        """

        logger.debug(
            f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
        )

        if self.bucket_type == BlobType.R2:
            if not all(
                credentials.get(key)
                for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
            ):
                raise ConnectorMissingCredentialError("Cloudflare R2")
            self.s3_client = boto3.client(
                "s3",
                endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com",
                aws_access_key_id=credentials["r2_access_key_id"],
                aws_secret_access_key=credentials["r2_secret_access_key"],
                region_name="auto",
                config=Config(signature_version="s3v4"),
            )

        elif self.bucket_type == BlobType.S3:
            if not all(
                credentials.get(key)
                for key in ["aws_access_key_id", "aws_secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Amazon S3")

            session = boto3.Session(
                aws_access_key_id=credentials["aws_access_key_id"],
                aws_secret_access_key=credentials["aws_secret_access_key"],
            )
            self.s3_client = session.client("s3")

        elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
            if not all(
                credentials.get(key) for key in ["access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Google Cloud Storage")
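
            # GCS is reached through its S3-compatible XML API; the access key /
            # secret pair is expected to be an HMAC key created in the bucket's
            # interoperability settings (assumption based on the endpoint used below).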
            self.s3_client = boto3.client(
                "s3",
                endpoint_url="https://storage.googleapis.com",
                aws_access_key_id=credentials["access_key_id"],
                aws_secret_access_key=credentials["secret_access_key"],
                region_name="auto",
            )

        elif self.bucket_type == BlobType.OCI_STORAGE:
            if not all(
                credentials.get(key)
                for key in ["namespace", "region", "access_key_id", "secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")

            self.s3_client = boto3.client(
                "s3",
                endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com",
                aws_access_key_id=credentials["access_key_id"],
                aws_secret_access_key=credentials["secret_access_key"],
                region_name=credentials["region"],
            )

        else:
            raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

        return None

    def _download_object(self, key: str) -> bytes:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")
        object = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
        return object["Body"].read()

    # NOTE: Left in as may be useful for one-off access to documents and sharing across orgs.
    # def _get_presigned_url(self, key: str) -> str:
    #     if self.s3_client is None:
    #         raise ConnectorMissingCredentialError("Blob storage")

    #     url = self.s3_client.generate_presigned_url(
    #         "get_object",
    #         Params={"Bucket": self.bucket_name, "Key": key},
    #         ExpiresIn=self.presign_length,
    #     )
    #     return url

    def _get_blob_link(self, key: str) -> str:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        if self.bucket_type == BlobType.R2:
            account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
            return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}"

        elif self.bucket_type == BlobType.S3:
            region = self.s3_client.meta.region_name
            return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}"

        elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
            return f"https://storage.cloud.google.com/{self.bucket_name}/{key}"

        elif self.bucket_type == BlobType.OCI_STORAGE:
            namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
            region = self.s3_client.meta.region_name
            return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}"

        else:
            raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
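
    # Streams Documents in batches of self.batch_size, limited to objects under
    # self.prefix whose LastModified timestamp falls within [start, end].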
    def _yield_blob_objects(
        self,
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        # list_objects_v2 is paginated; the paginator follows continuation tokens
        # across pages automatically.
        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)

        batch: list[Document] = []
        for page in pages:
            if "Contents" not in page:
                continue

            for obj in page["Contents"]:
                # Skip "directory" placeholder keys.
                if obj["Key"].endswith("/"):
                    continue

                last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)

                if not start <= last_modified <= end:
                    continue

                downloaded_file = self._download_object(obj["Key"])
                link = self._get_blob_link(obj["Key"])
                name = os.path.basename(obj["Key"])

                try:
                    text = extract_file_text(
                        BytesIO(downloaded_file),
                        file_name=name,
                        break_on_unprocessable=False,
                    )
                    batch.append(
                        Document(
                            id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
                            sections=[Section(link=link, text=text)],
                            source=DocumentSource(self.bucket_type.value),
                            semantic_identifier=name,
                            doc_updated_at=last_modified,
                            metadata={},
                        )
                    )
                    if len(batch) == self.batch_size:
                        yield batch
                        batch = []

                except Exception as e:
                    logger.exception(
                        f"Error extracting text from object {obj['Key']}: {e}"
                    )
        if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        logger.debug("Loading blob objects")
        return self._yield_blob_objects(
            start=datetime(1970, 1, 1, tzinfo=timezone.utc),
            end=datetime.now(timezone.utc),
        )

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blob storage")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        for batch in self._yield_blob_objects(start_datetime, end_datetime):
            yield batch

        return None
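

# Quick manual test: with AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY exported (and
# optionally BUCKET_TYPE and BUCKET_NAME), running this module directly prints the
# first batch of documents loaded from the bucket.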
if __name__ == "__main__":
    credentials_dict = {
        "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
        "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
    }

    # Initialize the connector
    connector = BlobStorageConnector(
        bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
        bucket_name=os.environ.get("BUCKET_NAME") or "test",
        prefix="",
    )

    try:
        connector.load_credentials(credentials_dict)
        document_batch_generator = connector.load_from_state()
        for document_batch in document_batch_generator:
            print("First batch of documents:")
            for doc in document_batch:
                print(f"Document ID: {doc.id}")
                print(f"Semantic Identifier: {doc.semantic_identifier}")
                print(f"Source: {doc.source}")
                print(f"Updated At: {doc.doc_updated_at}")
                print("Sections:")
                for section in doc.sections:
                    print(f"  - Link: {section.link}")
                    print(f"  - Text: {section.text[:100]}...")
                print("---")
            break

    except ConnectorMissingCredentialError as e:
        print(f"Error: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")