Merge pull request #9596 from engineeringpatrick/main

feat: add s3 key prefix support
This commit is contained in:
Timothy Jaeryang Baek 2025-02-09 13:05:41 -08:00 committed by GitHub
commit f25df15997
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 8 deletions

View File

@@ -660,6 +660,7 @@ S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID", None)
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY", None)
S3_REGION_NAME = os.environ.get("S3_REGION_NAME", None)
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", None)
# Optional prefix prepended to every object key written to the S3 bucket
# (joined with the filename by the S3 storage provider). None/empty disables it.
S3_KEY_PREFIX = os.environ.get("S3_KEY_PREFIX", None)
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL", None)
GCS_BUCKET_NAME = os.environ.get("GCS_BUCKET_NAME", None)

View File

@@ -10,6 +10,7 @@ from open_webui.config import (
S3_ACCESS_KEY_ID,
S3_BUCKET_NAME,
S3_ENDPOINT_URL,
S3_KEY_PREFIX,
S3_REGION_NAME,
S3_SECRET_ACCESS_KEY,
GCS_BUCKET_NAME,
@@ -93,34 +94,36 @@ class S3StorageProvider(StorageProvider):
aws_secret_access_key=S3_SECRET_ACCESS_KEY,
)
self.bucket_name = S3_BUCKET_NAME
self.key_prefix = S3_KEY_PREFIX if S3_KEY_PREFIX else ""
def upload_file(self, file: BinaryIO, filename: str) -> Tuple[bytes, str]:
    """Upload a file to S3 under the configured key prefix.

    The file is first persisted locally via LocalStorageProvider, then
    uploaded to S3.

    Args:
        file: Binary stream with the file contents.
        filename: Name used both locally and as the final S3 key segment.

    Returns:
        Tuple of (file contents as bytes, "s3://bucket/key" URI).

    Raises:
        RuntimeError: if the S3 client reports an upload error.
    """
    _, file_path = LocalStorageProvider.upload_file(file, filename)
    # Build the key with "/" explicitly: os.path.join would emit "\" on
    # Windows, which is not a valid S3 key separator.
    if self.key_prefix:
        s3_key = f"{self.key_prefix.rstrip('/')}/{filename}"
    else:
        s3_key = filename
    try:
        self.s3_client.upload_file(file_path, self.bucket_name, s3_key)
        # Read contents inside a context manager so the handle is closed
        # promptly instead of leaking until garbage collection.
        with open(file_path, "rb") as f:
            contents = f.read()
        return (
            contents,
            "s3://" + self.bucket_name + "/" + s3_key,
        )
    except ClientError as e:
        raise RuntimeError(f"Error uploading file to S3: {e}")
def get_file(self, file_path: str) -> str:
    """Download an object from S3 into the local upload directory.

    Args:
        file_path: Full "s3://bucket/key" style path.

    Returns:
        The local filesystem path of the downloaded file.

    Raises:
        RuntimeError: if the S3 client reports a download error.
    """
    try:
        key = self._extract_s3_key(file_path)
        destination = self._get_local_file_path(key)
        self.s3_client.download_file(self.bucket_name, key, destination)
        return destination
    except ClientError as e:
        raise RuntimeError(f"Error downloading file from S3: {e}")
def delete_file(self, file_path: str) -> None:
    """Delete a single object from S3 storage.

    Args:
        file_path: Full "s3://bucket/key" style path of the object.

    Raises:
        RuntimeError: if the S3 client reports a deletion error.
    """
    # NOTE: the previous `filename = file_path.split("/")[-1]` was a dead
    # local left over from the key-prefix refactor; the key now comes from
    # _extract_s3_key, which keeps any internal path segments.
    try:
        s3_key = self._extract_s3_key(file_path)
        self.s3_client.delete_object(Bucket=self.bucket_name, Key=s3_key)
    except ClientError as e:
        raise RuntimeError(f"Error deleting file from S3: {e}")
@@ -133,6 +136,9 @@ class S3StorageProvider(StorageProvider):
response = self.s3_client.list_objects_v2(Bucket=self.bucket_name)
if "Contents" in response:
for content in response["Contents"]:
# Skip objects that were not uploaded from open-webui in the first place
if not content["Key"].startswith(self.key_prefix): continue
self.s3_client.delete_object(
Bucket=self.bucket_name, Key=content["Key"]
)
@@ -142,6 +148,12 @@ class S3StorageProvider(StorageProvider):
# Always delete from local storage
LocalStorageProvider.delete_all_files()
# The s3 key is the name assigned to an object. It excludes the bucket name, but includes the internal path and the file name.
def _extract_s3_key(self, full_file_path: str) -> str:
return '/'.join(full_file_path.split("//")[1].split("/")[1:])
def _get_local_file_path(self, s3_key: str) -> str:
    """Map an S3 key to a path inside UPLOAD_DIR, keeping only the basename."""
    basename = s3_key.rsplit("/", 1)[-1]
    return f"{UPLOAD_DIR}/{basename}"
class GCSStorageProvider(StorageProvider):
def __init__(self):