danswer/backend/onyx/utils/variable_functionality.py
2024-12-13 09:56:10 -08:00

161 lines
5.2 KiB
Python

import functools
import importlib
import inspect
from typing import Any
from typing import TypeVar
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.utils.logger import setup_logger
# Module-level logger shared by the functions defined in this file.
logger = setup_logger()
class OnyxVersion:
    """Mutable flag recording whether the Enterprise Edition (EE) is active.

    A fresh instance starts in the non-EE state. The flag can only be
    switched on; this class offers no method to switch it back off.
    """

    def __init__(self) -> None:
        """Start in the non-EE state."""
        self._is_ee: bool = False

    def set_ee(self) -> None:
        """Mark this instance as running the Enterprise Edition."""
        self._is_ee = True

    def is_ee_version(self) -> bool:
        """Return True once `set_ee` has been called, False before that."""
        return self._is_ee
# Module-level singleton holding the edition flag consulted by the functions
# below; it starts as non-EE until its set_ee() method is called.
global_version = OnyxVersion()
def set_is_ee_based_on_env_variable() -> None:
    """Turn on the global EE flag when the app config asks for it.

    Does nothing when ENTERPRISE_EDITION_ENABLED is falsy or when
    `global_version` is already marked as EE, so repeated calls emit the
    notice log line at most once.
    """
    already_ee = global_version.is_ee_version()
    if not ENTERPRISE_EDITION_ENABLED or already_ee:
        return

    logger.notice("Enterprise Edition enabled")
    global_version.set_ee()
@functools.lru_cache(maxsize=128)
def fetch_versioned_implementation(module: str, attribute: str) -> Any:
    """Resolve `attribute` from the EE or the default flavour of `module`.

    When `global_version` reports EE, the lookup targets ``ee.<module>``;
    otherwise it targets ``<module>`` directly. If the EE import fails with a
    ModuleNotFoundError whose message mentions ``ee.onyx``, the lookup is
    retried against the plain ``<module>``.

    Results are memoised by ``functools.lru_cache`` (up to 128 entries). The
    cache key is ``(module, attribute)`` only: the EE flag is read inside the
    function and is not part of the key, so a value cached before the flag
    changes keeps being returned afterwards.

    Args:
        module (str): Dotted module path without the ``ee.`` prefix.
        attribute (str): Name of the attribute to read from the module.

    Returns:
        Any: Whatever object the resolved module exposes under `attribute`.

    Raises:
        ModuleNotFoundError: In non-EE mode when the import fails; in EE mode
            when the failure message does not mention ``ee.onyx``, or when the
            retry against the plain module also fails.
        AttributeError: If a module imports successfully but does not define
            `attribute` (this case is not caught and there is no retry).

    Logs:
        A debug line on every uncached call, and a warning whenever the first
        import attempt raises ModuleNotFoundError.
    """
    logger.debug("Fetching versioned implementation for %s.%s", module, attribute)

    ee_enabled = global_version.is_ee_version()
    if ee_enabled:
        candidate = f"ee.{module}"
    else:
        candidate = module

    try:
        loaded = importlib.import_module(candidate)
        return getattr(loaded, attribute)
    except ModuleNotFoundError as e:
        logger.warning(
            "Failed to fetch versioned implementation for %s.%s: %s",
            candidate,
            attribute,
            e,
        )

        # Outside EE there is nothing to fall back to.
        if not ee_enabled:
            raise

        # A failure that does not name an ee.onyx module most likely means a
        # third-party dependency is not installed. Surface it instead of
        # letting the server start up in a half-working state.
        if "ee.onyx" not in str(e):
            raise e

        # The EE variant simply does not exist: use the MIT implementation.
        # This lets MIT code be developed on its own, with EE behaviour
        # layered on later, similar to feature flagging.
        mit_module = importlib.import_module(module)
        return getattr(mit_module, attribute)
# Type variable that ties the type of the `fallback` argument to the return
# type of fetch_versioned_implementation_with_fallback.
T = TypeVar("T")
def fetch_versioned_implementation_with_fallback(
    module: str, attribute: str, fallback: T
) -> T:
    """Best-effort variant of `fetch_versioned_implementation`.

    Delegates to `fetch_versioned_implementation` and, if that call raises any
    `Exception` (missing module, missing attribute, ...), returns `fallback`
    instead of propagating the error. Nothing is logged by this function
    itself; the only log output comes from `fetch_versioned_implementation`.

    Args:
        module (str): The name of the module from which to fetch the attribute.
        attribute (str): The name of the attribute to fetch from the module.
        fallback (T): Value handed back when the fetch fails.

    Returns:
        T: The fetched implementation on success, otherwise `fallback`.
    """
    try:
        resolved = fetch_versioned_implementation(module, attribute)
    except Exception:
        # Deliberate best-effort: any failure means "use the fallback".
        resolved = fallback
    return resolved
def noop_fallback(*args: Any, **kwargs: Any) -> None:
    """Accept any call signature, do nothing, and return None.

    Intended as a placeholder callable wherever a real implementation is
    optional.

    Args:
        *args (Any): Ignored positional arguments.
        **kwargs (Any): Ignored keyword arguments.

    Returns:
        None
    """
    return None
def fetch_ee_implementation_or_noop(
    module: str, attribute: str, noop_return_value: Any = None
) -> Any:
    """Return the EE implementation when EE is on, otherwise a no-op callable.

    With EE enabled the attribute is resolved through
    `fetch_versioned_implementation`; a failure is logged at error level and
    re-raised. With EE disabled nothing is imported and a stand-in callable
    is returned instead.

    Args:
        module (str): The name of the module from which to fetch the attribute.
        attribute (str): The name of the attribute to fetch from the module.
        noop_return_value (Any): Controls the stand-in used when EE is off.
            If it is a coroutine function, the stand-in is an async function
            that awaits ``noop_return_value(*args, **kwargs)`` and returns its
            result. Otherwise the stand-in is a plain function that ignores
            its arguments and returns `noop_return_value` as-is.

    Returns:
        Any: The fetched EE implementation, or the stand-in callable.

    Raises:
        Exception: Whatever `fetch_versioned_implementation` raised, when EE
            is enabled and the fetch fails.
    """
    if global_version.is_ee_version():
        try:
            return fetch_versioned_implementation(module, attribute)
        except Exception as e:
            logger.error(
                f"Failed to fetch implementation for {module}.{attribute}: {e}"
            )
            raise

    # EE is off from here on: build a stand-in matching the sync/async nature
    # of what the caller supplied.
    if inspect.iscoroutinefunction(noop_return_value):

        async def async_noop(*args: Any, **kwargs: Any) -> Any:
            # Forward the call to the supplied coroutine function.
            return await noop_return_value(*args, **kwargs)

        return async_noop

    def sync_noop(*args: Any, **kwargs: Any) -> Any:
        # Arguments are ignored; the configured value is returned unchanged.
        return noop_return_value

    return sync_noop