import os
from datetime import datetime
from datetime import timezone
from io import BytesIO
from typing import Any
from typing import Optional
import boto3  # type: ignore
from botocore.client import Config  # type: ignore
from mypy_boto3_s3 import S3Client  # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger

logger = setup_logger()


class BlobStorageConnector(LoadConnector, PollConnector):
def __init__(
self,
bucket_type: str,
bucket_name: str,
prefix: str = "",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.bucket_type: BlobType = BlobType(bucket_type)
self.bucket_name = bucket_name
self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/"
self.batch_size = batch_size
        self.s3_client: Optional[S3Client] = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Checks for boto3 credentials based on the bucket type.
(1) R2: Access Key ID, Secret Access Key, Account ID
(2) S3: AWS Access Key ID, AWS Secret Access Key
(3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key, Project ID
(4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key
For each bucket type, the method initializes the appropriate S3 client:
- R2: Uses Cloudflare R2 endpoint with S3v4 signature
- S3: Creates a standard boto3 S3 client
- GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint
- OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint
Raises ConnectorMissingCredentialError if required credentials are missing.
Raises ValueError for unsupported bucket types.
"""
        logger.debug(
            f"Loading credentials for {self.bucket_name} of type {self.bucket_type}"
        )
if self.bucket_type == BlobType.R2:
if not all(
credentials.get(key)
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
):
raise ConnectorMissingCredentialError("Cloudflare R2")
self.s3_client = boto3.client(
"s3",
endpoint_url=f"https://{credentials['account_id']}.r2.cloudflarestorage.com",
aws_access_key_id=credentials["r2_access_key_id"],
aws_secret_access_key=credentials["r2_secret_access_key"],
region_name="auto",
config=Config(signature_version="s3v4"),
)
        elif self.bucket_type == BlobType.S3:
            if not all(
                credentials.get(key)
                for key in ["aws_access_key_id", "aws_secret_access_key"]
            ):
                raise ConnectorMissingCredentialError("Amazon S3")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],
aws_secret_access_key=credentials["aws_secret_access_key"],
)
self.s3_client = session.client("s3")
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
if not all(
credentials.get(key) for key in ["access_key_id", "secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
self.s3_client = boto3.client(
"s3",
endpoint_url="https://storage.googleapis.com",
aws_access_key_id=credentials["access_key_id"],
aws_secret_access_key=credentials["secret_access_key"],
region_name="auto",
)
elif self.bucket_type == BlobType.OCI_STORAGE:
if not all(
credentials.get(key)
for key in ["namespace", "region", "access_key_id", "secret_access_key"]
):
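            # OCI Object Storage exposes an S3 compatibility endpoint scoped by
            # tenancy namespace and region.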
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
self.s3_client = boto3.client(
"s3",
endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com",
aws_access_key_id=credentials["access_key_id"],
aws_secret_access_key=credentials["secret_access_key"],
region_name=credentials["region"],
)
else:
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
        return None

    def _download_object(self, key: str) -> bytes:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
        response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
        return response["Body"].read()

    # NOTE: Left in as may be useful for one-off access to documents and sharing across orgs.
    # def _get_presigned_url(self, key: str) -> str:
    #     if self.s3_client is None:
    #         raise ConnectorMissingCredentialError("Blob storage")
# url = self.s3_client.generate_presigned_url(
# "get_object",
# Params={"Bucket": self.bucket_name, "Key": key},
# ExpiresIn=self.presign_length,
# )
    #     return url

    def _get_blob_link(self, key: str) -> str:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
if self.bucket_type == BlobType.R2:
account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
return f"https://{account_id}.r2.cloudflarestorage.com/{self.bucket_name}/{key}"
elif self.bucket_type == BlobType.S3:
region = self.s3_client.meta.region_name
return f"https://{self.bucket_name}.s3.{region}.amazonaws.com/{key}"
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
return f"https://storage.cloud.google.com/{self.bucket_name}/{key}"
elif self.bucket_type == BlobType.OCI_STORAGE:
namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0]
region = self.s3_client.meta.region_name
return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{key}"
else:
            raise ValueError(f"Unsupported bucket type: {self.bucket_type}")

    def _yield_blob_objects(
self,
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
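        # list_objects_v2 returns at most 1,000 keys per response; the paginator
        # follows continuation tokens so every matching object is visited.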
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
batch: list[Document] = []
for page in pages:
if "Contents" not in page:
continue
for obj in page["Contents"]:
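                # Keys ending in "/" are folder placeholder objects with no file
                # content; skip them.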
if obj["Key"].endswith("/"):
continue
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
if not start <= last_modified <= end:
continue
downloaded_file = self._download_object(obj["Key"])
link = self._get_blob_link(obj["Key"])
name = os.path.basename(obj["Key"])
try:
text = extract_file_text(
BytesIO(downloaded_file),
file_name=name,
break_on_unprocessable=False,
)
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[Section(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
metadata={},
)
)
if len(batch) == self.batch_size:
yield batch
batch = []
                except Exception as e:
                    logger.exception(
                        f"Error extracting text from object {obj['Key']}: {e}"
                    )
if batch:
            yield batch

    def load_from_state(self) -> GenerateDocumentsOutput:
logger.debug("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),
        )

    def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blob storage")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
for batch in self._yield_blob_objects(start_datetime, end_datetime):
yield batch
return None
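

# Example (illustrative, not part of the connector): incremental sync via
# poll_source, with bounds given as seconds since the Unix epoch. Assumes
# credentials were loaded first, as in the smoke test below.
#
#     import time
#     for batch in connector.poll_source(start=time.time() - 3600, end=time.time()):
#         for doc in batch:
#             print(doc.id)
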
if __name__ == "__main__":
credentials_dict = {
"aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"),
"aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"),
}
# Initialize the connector
connector = BlobStorageConnector(
bucket_type=os.environ.get("BUCKET_TYPE") or "s3",
bucket_name=os.environ.get("BUCKET_NAME") or "test",
prefix="",
)
try:
connector.load_credentials(credentials_dict)
document_batch_generator = connector.load_from_state()
for document_batch in document_batch_generator:
print("First batch of documents:")
for doc in document_batch:
print(f"Document ID: {doc.id}")
print(f"Semantic Identifier: {doc.semantic_identifier}")
print(f"Source: {doc.source}")
print(f"Updated At: {doc.doc_updated_at}")
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
print(f" - Text: {section.text[:100]}...")
print("---")
break
except ConnectorMissingCredentialError as e:
print(f"Error: {e}")
except Exception as e:
print(f"An unexpected error occurred: {e}")