Proper tenant reset (#3015)

* add proper tenant reset

* clear comment

* minor formatting
This commit is contained in:
pablodanswer 2024-10-31 12:45:35 -07:00 committed by GitHub
parent add87fa1b4
commit 0b08bf4e3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 24 deletions

View File

@ -322,11 +322,18 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
"""
engine = get_sqlalchemy_engine()
# Store the previous tenant ID
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id is None:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@ -335,30 +342,35 @@ def get_session_with_tenant(
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
try:
# Establish a raw connection
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
# Set the search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
yield session
cursor.execute(f'SET search_path = "{tenant_id}"')
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
cursor.close()
# Bind the session to the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
# Restore the previous tenant ID
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
def set_search_path_on_checkout(

View File

@ -190,7 +190,6 @@ def bulk_invite_users(
)
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
normalized_emails = []
try:
for email in emails:
@ -206,6 +205,7 @@ def bulk_invite_users(
if MULTI_TENANT:
try:
add_users_to_tenant(normalized_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
@ -213,6 +213,8 @@ def bulk_invite_users(
detail="User has already been invited to a Danswer organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
initial_invited_users = get_invited_users()