mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 19:38:19 +02:00
Proper tenant reset (#3015)
* add proper tenant reset * clear comment * minor formatting
This commit is contained in:
parent
add87fa1b4
commit
0b08bf4e3f
@ -322,11 +322,18 @@ async def get_async_session_with_tenant(
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
|
||||
"""
|
||||
Generate a database session bound to a connection with the appropriate tenant schema set.
|
||||
This preserves the tenant ID across the session and reverts to the previous tenant ID
|
||||
after the session is closed.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Store the previous tenant ID
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
tenant_id = previous_tenant_id
|
||||
else:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
@ -335,30 +342,35 @@ def get_session_with_tenant(
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
try:
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
yield session
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
finally:
|
||||
# Restore the previous tenant ID
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
|
@ -190,7 +190,6 @@ def bulk_invite_users(
|
||||
)
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
normalized_emails = []
|
||||
try:
|
||||
for email in emails:
|
||||
@ -206,6 +205,7 @@ def bulk_invite_users(
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
add_users_to_tenant(normalized_emails, tenant_id)
|
||||
|
||||
except IntegrityError as e:
|
||||
if isinstance(e.orig, UniqueViolation):
|
||||
raise HTTPException(
|
||||
@ -213,6 +213,8 @@ def bulk_invite_users(
|
||||
detail="User has already been invited to a Danswer organization",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
|
||||
|
||||
initial_invited_users = get_invited_users()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user