danswer/backend/ee/onyx/utils/encryption.py
2024-12-13 09:56:10 -08:00

86 lines
2.8 KiB
Python

from functools import lru_cache
from os import urandom
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import algorithms
from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers import modes
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger()
@lru_cache(maxsize=1)
def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
if key_length < 16:
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
elif key_length > 32:
key = key[:32]
elif key_length not in (16, 24, 32):
valid_lengths = [16, 24, 32]
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
return encoded_key
def _encrypt_string(input_str: str) -> bytes:
    """Encrypt ``input_str`` with AES in CBC mode.

    When ``ENCRYPTION_KEY_SECRET`` is falsy, no encryption is performed and
    the encoded input string is returned as-is.

    Args:
        input_str: The text to encrypt.

    Returns:
        The 16-byte random IV followed by the ciphertext.
    """
    if not ENCRYPTION_KEY_SECRET:
        return input_str.encode()

    aes_key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
    # Fresh random IV for every call; it is prepended to the output so the
    # decrypting side can read it back from the first 16 bytes.
    init_vector = urandom(16)

    # CBC works on whole blocks, so pad the plaintext with PKCS7 first.
    pkcs7 = padding.PKCS7(algorithms.AES.block_size).padder()
    padded_plaintext = pkcs7.update(input_str.encode()) + pkcs7.finalize()

    aes_cbc = Cipher(
        algorithms.AES(aes_key),
        modes.CBC(init_vector),
        backend=default_backend(),
    ).encryptor()
    ciphertext = aes_cbc.update(padded_plaintext) + aes_cbc.finalize()

    return init_vector + ciphertext
def _decrypt_bytes(input_bytes: bytes) -> str:
    """Decrypt bytes laid out as a 16-byte IV followed by AES-CBC ciphertext.

    When ``ENCRYPTION_KEY_SECRET`` is falsy, the input is treated as
    unencrypted and simply decoded to ``str``.

    Args:
        input_bytes: The IV-prefixed ciphertext (or plain encoded text when
            no secret is configured).

    Returns:
        The decrypted, unpadded and decoded string.
    """
    if not ENCRYPTION_KEY_SECRET:
        return input_bytes.decode()

    aes_key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
    # The first 16 bytes carry the IV; everything after is ciphertext.
    init_vector, ciphertext = input_bytes[:16], input_bytes[16:]

    aes_cbc = Cipher(
        algorithms.AES(aes_key),
        modes.CBC(init_vector),
        backend=default_backend(),
    ).decryptor()
    padded_plaintext = aes_cbc.update(ciphertext) + aes_cbc.finalize()

    # Strip the PKCS7 padding before decoding.
    pkcs7 = padding.PKCS7(algorithms.AES.block_size).unpadder()
    plaintext = pkcs7.update(padded_plaintext) + pkcs7.finalize()

    return plaintext.decode()
def encrypt_string_to_bytes(input_str: str) -> bytes:
    """Encrypt ``input_str`` using the versioned ``_encrypt_string`` implementation.

    The implementation is looked up through ``fetch_versioned_implementation``
    with module path ``"onyx.utils.encryption"`` and then called with the
    input string.

    Args:
        input_str: The text to encrypt.

    Returns:
        Whatever the resolved implementation returns for ``input_str``.
    """
    encrypt_impl = fetch_versioned_implementation(
        "onyx.utils.encryption",
        "_encrypt_string",
    )
    encrypted = encrypt_impl(input_str)
    return encrypted
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
    """Decrypt ``input_bytes`` using the versioned ``_decrypt_bytes`` implementation.

    The implementation is looked up through ``fetch_versioned_implementation``
    with module path ``"onyx.utils.encryption"`` and then called with the
    input bytes.

    Args:
        input_bytes: The bytes to decrypt.

    Returns:
        Whatever the resolved implementation returns for ``input_bytes``.
    """
    decrypt_impl = fetch_versioned_implementation(
        "onyx.utils.encryption",
        "_decrypt_bytes",
    )
    decrypted = decrypt_impl(input_bytes)
    return decrypted
def test_encryption() -> None:
    """Round-trip a sample string through the public encrypt/decrypt helpers.

    Raises:
        RuntimeError: If the decrypted result differs from the original.
    """
    sample = "Onyx is the BEST!"
    round_tripped = decrypt_bytes_to_string(encrypt_string_to_bytes(sample))
    if sample == round_tripped:
        return
    raise RuntimeError("Encryption decryption test failed")