From af588461d24139bd0059b5fd671d39bd9b6d2698 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 10 May 2024 18:50:18 -0700 Subject: [PATCH] Enable Encryption --- backend/danswer/configs/app_configs.py | 2 +- backend/ee/danswer/main.py | 3 + backend/ee/danswer/utils/encryption.py | 85 ++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 backend/ee/danswer/utils/encryption.py diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 222fddcb1e..01ce4c0527 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -46,7 +46,7 @@ DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED # Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive # information. This provides an extra layer of security on top of Postgres access controls # and is available in Danswer EE -ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") +ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or "" # Turn off mask if admin users should see full credentials for data connectors. 
@lru_cache(maxsize=1)
def _get_trimmed_key(key: str) -> bytes:
    """Derive a valid AES key from the configured secret string.

    The secret is encoded to bytes and cut down to the largest AES key size
    (32, 24 or 16 bytes) that it can fill. Secrets that are already exactly
    16, 24 or 32 bytes long are returned unchanged.

    Trimming is done on the encoded bytes, not on the string, so secrets
    containing multi-byte characters still yield a key of a valid size.

    The result is memoized for the most recently used secret.

    Args:
        key: The raw secret (normally ENCRYPTION_KEY_SECRET).

    Returns:
        A 16-, 24- or 32-byte key.

    Raises:
        RuntimeError: If the encoded secret is shorter than 16 bytes.
    """
    encoded_key = key.encode()

    # Longest first: keep as much of the secret as a valid key size allows.
    for valid_length in (32, 24, 16):
        if len(encoded_key) >= valid_length:
            return encoded_key[:valid_length]

    raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
def _encrypt_string(input_str: str) -> bytes:
    """Encrypt ``input_str`` with AES in CBC mode.

    When ENCRYPTION_KEY_SECRET is empty the text is returned as plain encoded
    bytes. Otherwise the output is a freshly generated 16-byte IV followed by
    the ciphertext of the PKCS7-padded input.
    """
    if ENCRYPTION_KEY_SECRET:
        aes_key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
        init_vector = urandom(16)

        # CBC works on whole blocks, so pad the plaintext first.
        pkcs7 = padding.PKCS7(algorithms.AES.block_size).padder()
        padded_plaintext = pkcs7.update(input_str.encode()) + pkcs7.finalize()

        aes_cbc = Cipher(
            algorithms.AES(aes_key),
            modes.CBC(init_vector),
            backend=default_backend(),
        ).encryptor()
        ciphertext = aes_cbc.update(padded_plaintext) + aes_cbc.finalize()

        # The IV is stored in front of the ciphertext so decryption can find it.
        return init_vector + ciphertext

    return input_str.encode()


def _decrypt_bytes(input_bytes: bytes) -> str:
    """Reverse ``_encrypt_string``.

    When ENCRYPTION_KEY_SECRET is empty the bytes are simply decoded.
    Otherwise the first 16 bytes are taken as the IV and the remainder is
    decrypted with AES-CBC and stripped of its PKCS7 padding.
    """
    if ENCRYPTION_KEY_SECRET:
        aes_key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
        init_vector, ciphertext = input_bytes[:16], input_bytes[16:]

        aes_cbc = Cipher(
            algorithms.AES(aes_key),
            modes.CBC(init_vector),
            backend=default_backend(),
        ).decryptor()
        padded_plaintext = aes_cbc.update(ciphertext) + aes_cbc.finalize()

        pkcs7 = padding.PKCS7(algorithms.AES.block_size).unpadder()
        plaintext = pkcs7.update(padded_plaintext) + pkcs7.finalize()
        return plaintext.decode()

    return input_bytes.decode()


def encrypt_string_to_bytes(input_str: str) -> bytes:
    """Encrypt a string through the versioned ``_encrypt_string`` implementation.

    The implementation is looked up with fetch_versioned_implementation under
    the module path "danswer.utils.encryption" (presumably so the EE or base
    variant is chosen at runtime -- confirm against that helper).
    """
    return fetch_versioned_implementation(
        "danswer.utils.encryption", "_encrypt_string"
    )(input_str)


def decrypt_bytes_to_string(input_bytes: bytes) -> str:
    """Decrypt bytes through the versioned ``_decrypt_bytes`` implementation.

    The implementation is looked up with fetch_versioned_implementation under
    the module path "danswer.utils.encryption".
    """
    return fetch_versioned_implementation(
        "danswer.utils.encryption", "_decrypt_bytes"
    )(input_bytes)


def test_encryption():
    """Round-trip a fixed sample through encrypt/decrypt as a startup check.

    Raises:
        RuntimeError: If the decrypted text differs from the sample.
    """
    sample = "Danswer is the BEST!"
    round_tripped = decrypt_bytes_to_string(encrypt_string_to_bytes(sample))
    if round_tripped != sample:
        raise RuntimeError("Encryption decryption test failed")