from functools import lru_cache from os import urandom from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import algorithms from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers import modes from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() @lru_cache(maxsize=1) def _get_trimmed_key(key: str) -> bytes: encoded_key = key.encode() key_length = len(encoded_key) if key_length < 16: raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short") elif key_length > 32: key = key[:32] elif key_length not in (16, 24, 32): valid_lengths = [16, 24, 32] key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))] return encoded_key def _encrypt_string(input_str: str) -> bytes: if not ENCRYPTION_KEY_SECRET: return input_str.encode() key = _get_trimmed_key(ENCRYPTION_KEY_SECRET) iv = urandom(16) padder = padding.PKCS7(algorithms.AES.block_size).padder() padded_data = padder.update(input_str.encode()) + padder.finalize() cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) encryptor = cipher.encryptor() encrypted_data = encryptor.update(padded_data) + encryptor.finalize() return iv + encrypted_data def _decrypt_bytes(input_bytes: bytes) -> str: if not ENCRYPTION_KEY_SECRET: return input_bytes.decode() key = _get_trimmed_key(ENCRYPTION_KEY_SECRET) iv = input_bytes[:16] encrypted_data = input_bytes[16:] cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) decryptor = cipher.decryptor() decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize() unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize() return decrypted_data.decode() def encrypt_string_to_bytes(input_str: str) -> bytes: versioned_encryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_encrypt_string" ) return versioned_encryption_fn(input_str) def decrypt_bytes_to_string(input_bytes: bytes) -> str: versioned_decryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_decrypt_bytes" ) return versioned_decryption_fn(input_bytes) def test_encryption() -> None: test_string = "Onyx is the BEST!" encrypted_bytes = encrypt_string_to_bytes(test_string) decrypted_string = decrypt_bytes_to_string(encrypted_bytes) if test_string != decrypted_string: raise RuntimeError("Encryption decryption test failed")