Pass in tenant_id to kv_store in monitoring job

This commit is contained in:
Weves 2025-01-20 13:29:26 -08:00 committed by Chris Weaver
parent cc4953b560
commit 1378364686
4 changed files with 26 additions and 24 deletions

View File

@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore
def get_kv_store() -> KeyValueStore:
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore()
return PgRedisKVStore(tenant_id=tenant_id)

View File

@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
@contextmanager
def get_session(self) -> Iterator[Session]:
def _get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(tenant_id):
if not is_valid_schema_name(self.tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{tenant_id}"'))
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
yield session
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
@ -66,7 +66,7 @@ class PgRedisKVStore(KeyValueStore):
encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self.get_session() as session:
with self._get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
@ -88,7 +88,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
with self.get_session() as session:
with self._get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
@ -113,7 +113,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
with self.get_session() as session:
with self._get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
if result == 0:
raise KvKeyNotFoundError

View File

@ -212,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
get_or_generate_uuid(tenant_id=None)
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:

View File

@ -11,7 +11,7 @@ from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.engine import get_session_with_tenant
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
@ -41,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))
def get_or_generate_uuid(tenant_id: str | None = None) -> str:
def get_or_generate_uuid(tenant_id: str | None) -> str:
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.
@ -52,7 +52,7 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID
kv_store = get_kv_store()
kv_store = get_kv_store(tenant_id=tenant_id)
try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
@ -63,18 +63,18 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str:
return _CACHED_UUID
def _get_or_generate_instance_domain() -> str | None: #
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
global _CACHED_INSTANCE_DOMAIN
if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN
kv_store = get_kv_store()
kv_store = get_kv_store(tenant_id=tenant_id)
try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@ -94,16 +94,16 @@ def optional_telemetry(
if DISABLE_TELEMETRY:
return
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
try:
def telemetry_logic() -> None:
try:
customer_uuid = (
_get_or_generate_customer_id_mt(
tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
)
_get_or_generate_customer_id_mt(tenant_id)
if MULTI_TENANT
else get_or_generate_uuid()
else get_or_generate_uuid(tenant_id)
)
payload = {
"data": data,
@ -115,7 +115,9 @@ def optional_telemetry(
"is_cloud": MULTI_TENANT,
}
if ENTERPRISE_EDITION_ENABLED:
payload["instance_domain"] = _get_or_generate_instance_domain()
payload["instance_domain"] = _get_or_generate_instance_domain(
tenant_id
)
requests.post(
_DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},