From a5d2f0d9ac28f84c35dcac94c9f3b0ba4f234ac6 Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:29:04 -0800 Subject: [PATCH] =?UTF-8?q?Fix=20airtable=20connector=20w/=20mt=20cloud=20?= =?UTF-8?q?+=20move=20telem=20logic=20to=20match=20new=20st=E2=80=A6=20(#3?= =?UTF-8?q?868)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix airtable connector w/ mt cloud + move telem logic to match new standard * Address Greptile comment * Small fixes/improvements * Revert back monitoring frequency * Small monitoring fix --- .../celery/tasks/monitoring/tasks.py | 8 +++++- .../connectors/airtable/airtable_connector.py | 24 +++++++++++------- backend/onyx/key_value_store/factory.py | 4 +-- backend/onyx/key_value_store/store.py | 8 +++--- backend/onyx/main.py | 5 +++- backend/onyx/utils/telemetry.py | 25 +++++++++++-------- backend/shared_configs/contextvars.py | 10 ++++++++ 7 files changed, 55 insertions(+), 29 deletions(-) diff --git a/backend/onyx/background/celery/tasks/monitoring/tasks.py b/backend/onyx/background/celery/tasks/monitoring/tasks.py index d2116c60c..037560841 100644 --- a/backend/onyx/background/celery/tasks/monitoring/tasks.py +++ b/backend/onyx/background/celery/tasks/monitoring/tasks.py @@ -39,6 +39,7 @@ from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import redis_lock_dump from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR _MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes @@ -657,6 +658,9 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None: - Syncing speed metrics - Worker status and task counts """ + if tenant_id is not None: + CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) + task_logger.info("Starting background monitoring") r = get_redis_client(tenant_id=tenant_id) @@ -688,11 +692,13 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None: metrics = metric_fn() for metric in metrics: # double check to make sure we aren't double-emitting metrics - if metric.key is not None and not _has_metric_been_emitted( + if metric.key is None or not _has_metric_been_emitted( redis_std, metric.key ): metric.log() metric.emit(tenant_id) + + if metric.key is not None: _mark_metric_as_emitted(redis_std, metric.key) task_logger.info("Successfully collected background metrics") diff --git a/backend/onyx/connectors/airtable/airtable_connector.py b/backend/onyx/connectors/airtable/airtable_connector.py index 8739c246d..aa05201ee 100644 --- a/backend/onyx/connectors/airtable/airtable_connector.py +++ b/backend/onyx/connectors/airtable/airtable_connector.py @@ -1,4 +1,6 @@ +import contextvars from concurrent.futures import as_completed +from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from io import BytesIO from typing import Any @@ -347,15 +349,19 @@ class AirtableConnector(LoadConnector): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit batch tasks - future_to_record = { - executor.submit( - self._process_record, - record=record, - table_schema=table_schema, - primary_field_name=primary_field_name, - ): record - for record in batch_records - } + future_to_record: dict[Future, RecordDict] = {} + for record in batch_records: + # Capture the current context so that the thread gets the current tenant ID + current_context = contextvars.copy_context() + future_to_record[ + executor.submit( + current_context.run, + self._process_record, + record=record, + table_schema=table_schema, + primary_field_name=primary_field_name, + ) + ] = record # Wait for all tasks in this batch to complete for future in as_completed(future_to_record): diff --git a/backend/onyx/key_value_store/factory.py b/backend/onyx/key_value_store/factory.py index 77f8ea79f..c53f7ebac 100644 --- a/backend/onyx/key_value_store/factory.py +++ b/backend/onyx/key_value_store/factory.py @@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore from onyx.key_value_store.store import PgRedisKVStore -def get_kv_store(tenant_id: str | None = None) -> KeyValueStore: +def get_kv_store() -> KeyValueStore: # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in # It's read from the global thread level variable - return PgRedisKVStore(tenant_id=tenant_id) + return PgRedisKVStore() diff --git a/backend/onyx/key_value_store/store.py b/backend/onyx/key_value_store/store.py index 6db1b6ce1..f0811b7e5 100644 --- a/backend/onyx/key_value_store/store.py +++ b/backend/onyx/key_value_store/store.py @@ -18,7 +18,7 @@ from onyx.utils.logger import setup_logger from onyx.utils.special_types import JSON_ro from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA -from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -28,10 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day class PgRedisKVStore(KeyValueStore): - def __init__( - self, redis_client: Redis | None = None, tenant_id: str | None = None - ) -> None: - self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get() + def __init__(self, redis_client: Redis | None = None) -> None: + self.tenant_id = get_current_tenant_id() # If no redis_client is provided, fall back to the context var if redis_client is not None: diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 05150fee4..c8d0e2d14 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -109,7 +109,9 @@ from onyx.utils.variable_functionality import global_version from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable from shared_configs.configs import CORS_ALLOWED_ORIGIN from shared_configs.configs import MULTI_TENANT +from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import SENTRY_DSN +from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() @@ -212,7 +214,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: if not MULTI_TENANT: # We cache this at the beginning so there is no delay in the first telemetry - get_or_generate_uuid(tenant_id=None) + CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA) + get_or_generate_uuid() # If we are multi-tenant, we need to only set up initial public tables with Session(engine) as db_session: diff --git a/backend/onyx/utils/telemetry.py b/backend/onyx/utils/telemetry.py index 9e9bb1513..60c0bd4c2 100644 --- a/backend/onyx/utils/telemetry.py +++ b/backend/onyx/utils/telemetry.py @@ -1,3 +1,4 @@ +import contextvars import threading import uuid from enum import Enum @@ -41,7 +42,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str: return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id)) -def get_or_generate_uuid(tenant_id: str | None) -> str: +def get_or_generate_uuid() -> str: # TODO: split out the whole "instance UUID" generation logic into a separate # utility function. Telemetry should not be aware at all of how the UUID is # generated/stored. @@ -52,7 +53,7 @@ def get_or_generate_uuid(tenant_id: str | None) -> str: if _CACHED_UUID is not None: return _CACHED_UUID - kv_store = get_kv_store(tenant_id=tenant_id) + kv_store = get_kv_store() try: _CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY)) @@ -63,18 +64,18 @@ def get_or_generate_uuid(tenant_id: str | None) -> str: return _CACHED_UUID -def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: # +def _get_or_generate_instance_domain() -> str | None: # global _CACHED_INSTANCE_DOMAIN if _CACHED_INSTANCE_DOMAIN is not None: return _CACHED_INSTANCE_DOMAIN - kv_store = get_kv_store(tenant_id=tenant_id) + kv_store = get_kv_store() try: _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY)) except KvKeyNotFoundError: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: + with get_session_with_tenant() as db_session: first_user = db_session.query(User).first() if first_user: _CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1] @@ -103,7 +104,7 @@ def optional_telemetry( customer_uuid = ( _get_or_generate_customer_id_mt(tenant_id) if MULTI_TENANT - else get_or_generate_uuid(tenant_id) + else get_or_generate_uuid() ) payload = { "data": data, @@ -115,9 +116,7 @@ def optional_telemetry( "is_cloud": MULTI_TENANT, } if ENTERPRISE_EDITION_ENABLED: - payload["instance_domain"] = _get_or_generate_instance_domain( - tenant_id - ) + payload["instance_domain"] = _get_or_generate_instance_domain() requests.post( _DANSWER_TELEMETRY_ENDPOINT, headers={"Content-Type": "application/json"}, @@ -128,8 +127,12 @@ def optional_telemetry( # This way it silences all thread level logging as well pass - # Run in separate thread to have minimal overhead in main flows - thread = threading.Thread(target=telemetry_logic, daemon=True) + # Run in separate thread with the same context as the current thread + # This is to ensure that the thread gets the current tenant ID + current_context = contextvars.copy_context() + thread = threading.Thread( + target=lambda: current_context.run(telemetry_logic), daemon=True + ) thread.start() except Exception: # Should never interfere with normal functions of Onyx diff --git a/backend/shared_configs/contextvars.py b/backend/shared_configs/contextvars.py index df66b141c..c166ca69c 100644 --- a/backend/shared_configs/contextvars.py +++ b/backend/shared_configs/contextvars.py @@ -6,3 +6,13 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar( "current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA ) + + +"""Utils related to contextvars""" + + +def get_current_tenant_id() -> str: + tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() + if tenant_id is None: + raise RuntimeError("Tenant ID is not set. This should never happen.") + return tenant_id