From e56fd43ba6d9a2d9ccf792879f76064004e3a226 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 5 Oct 2024 16:08:28 -0700 Subject: [PATCH] cors update (#2686) --- backend/shared_configs/configs.py | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index ea37b031c..e8b599b77 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -1,4 +1,5 @@ import os +from typing import List from urllib.parse import urlparse # Used for logging @@ -76,16 +77,32 @@ PRESERVED_SEARCH_FIELDS = [ ] -# CORS def validate_cors_origin(origin: str) -> None: parsed = urlparse(origin) if parsed.scheme not in ["http", "https"] or not parsed.netloc: raise ValueError(f"Invalid CORS origin: '{origin}'") -CORS_ALLOWED_ORIGIN = os.environ.get("CORS_ALLOWED_ORIGIN", "*").split(",") or ["*"] +# Examples of valid values for the environment variable: +# - "" (allow all origins) +# - "http://example.com" (single origin) +# - "http://example.com,https://example.org" (multiple origins) +# - "*" (allow all origins) +CORS_ALLOWED_ORIGIN_ENV = os.environ.get("CORS_ALLOWED_ORIGIN", "") -# Validate non-wildcard origins -for origin in CORS_ALLOWED_ORIGIN: - if origin != "*" and (stripped_origin := origin.strip()): - validate_cors_origin(stripped_origin) +# Explicitly declare the type of CORS_ALLOWED_ORIGIN +CORS_ALLOWED_ORIGIN: List[str] + +if CORS_ALLOWED_ORIGIN_ENV: + # Split the environment variable into a list of origins + CORS_ALLOWED_ORIGIN = [ + origin.strip() + for origin in CORS_ALLOWED_ORIGIN_ENV.split(",") + if origin.strip() + ] + # Validate each origin in the list + for origin in CORS_ALLOWED_ORIGIN: + validate_cors_origin(origin) +else: + # If the environment variable is empty, allow all origins + CORS_ALLOWED_ORIGIN = ["*"]