mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-19 01:01:14 +02:00
rename context var
This commit is contained in:
parent
336f931c85
commit
622952e39b
@ -93,7 +93,7 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
@ -249,7 +249,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
@ -288,7 +288,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
|
||||
async def on_after_login(
|
||||
@ -342,7 +342,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_in_whitelist(account_email, tenant_id)
|
||||
verify_email_domain(account_email)
|
||||
@ -432,7 +432,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user.oidc_expiry = None # type: ignore
|
||||
|
||||
if token:
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
|
@ -27,7 +27,7 @@ from danswer.file_processing.extract_file_text import read_pdf_file
|
||||
from danswer.file_processing.extract_file_text import read_text_file
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
@ -175,7 +175,7 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = current_tenant_id.set(self.tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
@ -199,7 +199,7 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -57,7 +57,7 @@ from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@ -364,7 +364,7 @@ def process_message(
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = current_tenant_id.set(client.tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_bot_config = get_slack_bot_config_for_channel(
|
||||
@ -413,7 +413,7 @@ def process_message(
|
||||
apologize_for_fail(details, client)
|
||||
finally:
|
||||
if client.tenant_id:
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
@ -511,11 +511,9 @@ if __name__ == "__main__":
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
token = current_tenant_id.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
|
||||
latest_slack_bot_tokens = fetch_tokens()
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
if (
|
||||
tenant_id not in slack_bot_tokens
|
||||
|
@ -39,7 +39,7 @@ from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.configs.constants import TENANT_ID_PREFIX
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
@ -260,12 +260,12 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
current_tenant_id.set(tenant_id)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
token = request.cookies.get("tenant_details")
|
||||
if not token:
|
||||
current_value = current_tenant_id.get()
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
return current_value
|
||||
|
||||
@ -273,14 +273,14 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
|
||||
tenant_id = payload.get("tenant_id")
|
||||
if not tenant_id:
|
||||
return current_tenant_id.get()
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
current_tenant_id.set(tenant_id)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
return current_tenant_id.get()
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@ -291,7 +291,7 @@ async def get_async_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
@ -323,9 +323,9 @@ def get_session_with_tenant(
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
else:
|
||||
current_tenant_id.set(tenant_id)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
@ -361,7 +361,7 @@ def get_session_with_tenant(
|
||||
def set_search_path_on_checkout(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
@ -371,14 +371,14 @@ def set_search_path_on_checkout(
|
||||
|
||||
|
||||
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
@ -395,7 +395,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
|
@ -16,7 +16,7 @@ from danswer.key_value_store.interface import KeyValueStore
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
@ -28,7 +28,7 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
|
||||
|
||||
class PgRedisKVStore(KeyValueStore):
|
||||
def __init__(self) -> None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
self.redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@contextmanager
|
||||
@ -36,7 +36,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as session:
|
||||
if MULTI_TENANT:
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User must authenticate"
|
||||
|
@ -10,7 +10,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.db.engine import current_tenant_id
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.llm.interfaces import LLM
|
||||
@ -162,7 +162,7 @@ def retrieval_preprocessing(
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
|
||||
tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
@ -25,7 +25,7 @@ from danswer.db.connector_credential_pair import (
|
||||
update_connector_credential_pair_from_id,
|
||||
)
|
||||
from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.engine import current_tenant_id
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.enums import AccessType
|
||||
@ -277,7 +277,7 @@ def prune_cc_pair(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
tasks_created = try_creating_prune_generator_task(
|
||||
primary_app, cc_pair, db_session, r, current_tenant_id.get()
|
||||
primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not tasks_created:
|
||||
raise HTTPException(
|
||||
@ -363,7 +363,9 @@ def sync_cc_pair(
|
||||
|
||||
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair_id, tenant_id=current_tenant_id.get()),
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
),
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
|
@ -38,7 +38,7 @@ from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.auth import get_total_users
|
||||
from danswer.db.engine import current_tenant_id
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import DocumentSet__User
|
||||
@ -188,7 +188,7 @@ def bulk_invite_users(
|
||||
status_code=400, detail="Auth is disabled, cannot invite users"
|
||||
)
|
||||
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
normalized_emails = []
|
||||
try:
|
||||
@ -222,7 +222,9 @@ def bulk_invite_users(
|
||||
return number_of_invited_users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
|
||||
register_tenant_users(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
|
||||
)
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in all_emails:
|
||||
@ -250,13 +252,15 @@ def remove_invited_user(
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [user for user in user_emails if user != user_email.user_email]
|
||||
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
remove_users_from_tenant([user_email.user_email], tenant_id)
|
||||
number_of_invited_users = write_invited_users(remaining_users)
|
||||
|
||||
try:
|
||||
if MULTI_TENANT:
|
||||
register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
|
||||
register_tenant_users(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users(db_session)
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Request to update number of seats taken in control plane failed. "
|
||||
|
@ -21,7 +21,7 @@ from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -41,7 +41,7 @@ def check_token_rate_limits(
|
||||
versioned_rate_limit_strategy = fetch_versioned_implementation(
|
||||
"danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
|
||||
)
|
||||
return versioned_rate_limit_strategy(user, current_tenant_id.get())
|
||||
return versioned_rate_limit_strategy(user, CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
|
||||
|
||||
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
|
||||
|
@ -29,7 +29,7 @@ from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_group_permission_sync,
|
||||
)
|
||||
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -110,7 +110,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
@ -122,7 +122,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
|
@ -11,7 +11,7 @@ from fastapi import Response
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.db.engine import is_valid_schema_name
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
current_tenant_id.set(tenant_id)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
logger.info(f"Middleware set current_tenant_id to: {tenant_id}")
|
||||
|
||||
response = await call_next(request)
|
||||
|
@ -24,7 +24,7 @@ from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
|
||||
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
@ -55,7 +55,7 @@ def create_tenant(
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
run_alembic_migrations(tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@ -74,7 +74,7 @@ def create_tenant(
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
@ -89,7 +89,7 @@ def gate_product(
|
||||
2) User's card has declined
|
||||
"""
|
||||
tenant_id = product_gating_request.tenant_id
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
settings.product_gating = product_gating_request.product_gating
|
||||
@ -100,7 +100,7 @@ def gate_product(
|
||||
create_notification(None, product_gating_request.notification, db_session)
|
||||
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@router.get("/billing-information", response_model=BillingInformation)
|
||||
@ -108,14 +108,16 @@ async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation:
|
||||
logger.info("Fetching billing information")
|
||||
return BillingInformation(**fetch_billing_information(current_tenant_id.get()))
|
||||
return BillingInformation(
|
||||
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
|
||||
try:
|
||||
# Fetch tenant_id and current tenant's information
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
|
@ -8,7 +8,7 @@ from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.danswer.server.tenants.access import generate_data_plane_token
|
||||
from shared_configs.configs import current_tenant_id
|
||||
from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
@ -50,7 +50,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
tenant_id = current_tenant_id.get()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
|
@ -131,8 +131,10 @@ else:
|
||||
|
||||
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
|
||||
|
||||
current_tenant_id = contextvars.ContextVar(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
|
||||
|
||||
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user