mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-11 05:19:52 +02:00
Pass in tenant_id to kv_store in monitoring job
This commit is contained in:
parent
cc4953b560
commit
1378364686
@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
|
||||
|
||||
def get_kv_store() -> KeyValueStore:
|
||||
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
|
||||
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
|
||||
# It's read from the global thread level variable
|
||||
return PgRedisKVStore()
|
||||
return PgRedisKVStore(tenant_id=tenant_id)
|
||||
|
@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(
|
||||
self, redis_client: Redis | None = None, tenant_id: str | None = None
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
# If no redis_client is provided, fall back to the context var
|
||||
if redis_client is not None:
|
||||
self.redis_client = redis_client
|
||||
else:
|
||||
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
self.redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
|
||||
|
||||
@contextmanager
|
||||
def get_session(self) -> Iterator[Session]:
|
||||
def _get_session(self) -> Iterator[Session]:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User must authenticate"
|
||||
)
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
if not is_valid_schema_name(self.tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
|
||||
yield session
|
||||
|
||||
def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
|
||||
@ -66,7 +66,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
|
||||
encrypted_val = val if encrypt else None
|
||||
plain_val = val if not encrypt else None
|
||||
with self.get_session() as session:
|
||||
with self._get_session() as session:
|
||||
obj = session.query(KVStore).filter_by(key=key).first()
|
||||
if obj:
|
||||
obj.value = plain_val
|
||||
@ -88,7 +88,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with self.get_session() as session:
|
||||
with self._get_session() as session:
|
||||
obj = session.query(KVStore).filter_by(key=key).first()
|
||||
if not obj:
|
||||
raise KvKeyNotFoundError
|
||||
@ -113,7 +113,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")
|
||||
|
||||
with self.get_session() as session:
|
||||
with self._get_session() as session:
|
||||
result = session.query(KVStore).filter_by(key=key).delete() # type: ignore
|
||||
if result == 0:
|
||||
raise KvKeyNotFoundError
|
||||
|
@ -212,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
if not MULTI_TENANT:
|
||||
# We cache this at the beginning so there is no delay in the first telemetry
|
||||
get_or_generate_uuid()
|
||||
get_or_generate_uuid(tenant_id=None)
|
||||
|
||||
# If we are multi-tenant, we need to only set up initial public tables
|
||||
with Session(engine) as db_session:
|
||||
|
@ -11,7 +11,7 @@ from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
|
||||
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@ -41,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))
|
||||
|
||||
|
||||
def get_or_generate_uuid(tenant_id: str | None = None) -> str:
|
||||
def get_or_generate_uuid(tenant_id: str | None) -> str:
|
||||
# TODO: split out the whole "instance UUID" generation logic into a separate
|
||||
# utility function. Telemetry should not be aware at all of how the UUID is
|
||||
# generated/stored.
|
||||
@ -52,7 +52,7 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str:
|
||||
if _CACHED_UUID is not None:
|
||||
return _CACHED_UUID
|
||||
|
||||
kv_store = get_kv_store()
|
||||
kv_store = get_kv_store(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
|
||||
@ -63,18 +63,18 @@ def get_or_generate_uuid(tenant_id: str | None = None) -> str:
|
||||
return _CACHED_UUID
|
||||
|
||||
|
||||
def _get_or_generate_instance_domain() -> str | None: #
|
||||
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
|
||||
global _CACHED_INSTANCE_DOMAIN
|
||||
|
||||
if _CACHED_INSTANCE_DOMAIN is not None:
|
||||
return _CACHED_INSTANCE_DOMAIN
|
||||
|
||||
kv_store = get_kv_store()
|
||||
kv_store = get_kv_store(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
first_user = db_session.query(User).first()
|
||||
if first_user:
|
||||
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
|
||||
@ -94,16 +94,16 @@ def optional_telemetry(
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
|
||||
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
try:
|
||||
|
||||
def telemetry_logic() -> None:
|
||||
try:
|
||||
customer_uuid = (
|
||||
_get_or_generate_customer_id_mt(
|
||||
tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
_get_or_generate_customer_id_mt(tenant_id)
|
||||
if MULTI_TENANT
|
||||
else get_or_generate_uuid()
|
||||
else get_or_generate_uuid(tenant_id)
|
||||
)
|
||||
payload = {
|
||||
"data": data,
|
||||
@ -115,7 +115,9 @@ def optional_telemetry(
|
||||
"is_cloud": MULTI_TENANT,
|
||||
}
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
payload["instance_domain"] = _get_or_generate_instance_domain()
|
||||
payload["instance_domain"] = _get_or_generate_instance_domain(
|
||||
tenant_id
|
||||
)
|
||||
requests.post(
|
||||
_DANSWER_TELEMETRY_ENDPOINT,
|
||||
headers={"Content-Type": "application/json"},
|
||||
|
Loading…
x
Reference in New Issue
Block a user