mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-18 11:34:12 +02:00
Multi tenant tests (#3919)
* ensure fail on multi tenant successfully * attempted fix * udpate ingration tests * minor update * improve * improve workflow * fix migrations * many more logs * quick fix * improve * fix typo * quick nit * attempted fix * very minor clean up
This commit is contained in:
@@ -5,7 +5,6 @@ Revises: 47e5bef3a1d7
|
||||
Create Date: 2024-11-06 13:15:53.302644
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
@@ -20,13 +19,8 @@ down_revision = "47e5bef3a1d7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
logger.info(f"{revision}: create_table: slack_bot")
|
||||
# Create new slack_bot table
|
||||
op.create_table(
|
||||
"slack_bot",
|
||||
@@ -63,7 +57,6 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Handle existing Slack bot tokens first
|
||||
logger.info(f"{revision}: Checking for existing Slack bot.")
|
||||
bot_token = None
|
||||
app_token = None
|
||||
first_row_id = None
|
||||
@@ -71,15 +64,12 @@ def upgrade() -> None:
|
||||
try:
|
||||
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
|
||||
except Exception:
|
||||
logger.warning("No existing Slack bot tokens found.")
|
||||
tokens = {}
|
||||
|
||||
bot_token = tokens.get("bot_token")
|
||||
app_token = tokens.get("app_token")
|
||||
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Found bot and app tokens.")
|
||||
|
||||
session = Session(bind=op.get_bind())
|
||||
new_slack_bot = SlackBot(
|
||||
name="Slack Bot (Migrated)",
|
||||
@@ -170,10 +160,9 @@ def upgrade() -> None:
|
||||
# Clean up old tokens if they existed
|
||||
try:
|
||||
if bot_token and app_token:
|
||||
logger.info(f"{revision}: Removing old bot and app tokens.")
|
||||
get_kv_store().delete("slack_bot_tokens_config_key")
|
||||
except Exception:
|
||||
logger.warning("tried to delete tokens in dynamic config but failed")
|
||||
pass
|
||||
# Rename the table
|
||||
op.rename_table(
|
||||
"slack_bot_config__standard_answer_category",
|
||||
@@ -190,8 +179,6 @@ def upgrade() -> None:
|
||||
# Drop the table with CASCADE to handle dependent objects
|
||||
op.execute("DROP TABLE slack_bot_config CASCADE")
|
||||
|
||||
logger.info(f"{revision}: Migration complete.")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Recreate the old slack_bot_config table
|
||||
@@ -273,7 +260,7 @@ def downgrade() -> None:
|
||||
}
|
||||
get_kv_store().store("slack_bot_tokens_config_key", tokens)
|
||||
except Exception:
|
||||
logger.warning("Failed to save tokens back to KV store")
|
||||
pass
|
||||
|
||||
# Drop the new tables in reverse order
|
||||
op.drop_table("slack_channel_config")
|
||||
|
@@ -64,6 +64,7 @@ async def _get_tenant_id_from_request(
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if not token_data:
|
||||
|
@@ -24,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
@@ -85,7 +86,8 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
|
@@ -34,6 +34,7 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
@@ -286,7 +287,7 @@ def bulk_invite_users(
|
||||
detail=f"Invalid email address: {email} - {str(e)}",
|
||||
)
|
||||
|
||||
if MULTI_TENANT:
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
try:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
|
||||
|
@@ -70,8 +70,8 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
# Set up application files
|
||||
COPY ./onyx /app/onyx
|
||||
COPY ./shared_configs /app/shared_configs
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic_tenants /app/alembic_tenants
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY ./pytest.ini /app/pytest.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
|
@@ -24,35 +24,6 @@ def generate_auth_token() -> str:
|
||||
|
||||
|
||||
class TenantManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
tenant_id: str | None = None,
|
||||
initial_admin_email: str | None = None,
|
||||
referral_source: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
body = {
|
||||
"tenant_id": tenant_id,
|
||||
"initial_admin_email": initial_admin_email,
|
||||
"referral_source": referral_source,
|
||||
}
|
||||
|
||||
token = generate_auth_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-API-KEY": "",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/tenants/create",
|
||||
json=body,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all_users(
|
||||
user_performing_action: DATestUser | None = None,
|
||||
|
@@ -92,6 +92,7 @@ class UserManager:
|
||||
|
||||
# Set cookies in the headers
|
||||
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
|
||||
test_user.cookies = {"fastapiusersauth": session_cookie}
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
@@ -102,6 +103,7 @@ class UserManager:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me",
|
||||
headers=user_to_verify.headers,
|
||||
cookies=user_to_verify.cookies,
|
||||
)
|
||||
|
||||
if user_to_verify.is_active is False:
|
||||
|
@@ -242,6 +242,18 @@ def reset_postgres_multitenant() -> None:
|
||||
schema_name = schema[0]
|
||||
cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')
|
||||
|
||||
# Drop tables in the public schema
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT tablename FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
"""
|
||||
)
|
||||
public_tables = cur.fetchall()
|
||||
for table in public_tables:
|
||||
table_name = table[0]
|
||||
cur.execute(f'DROP TABLE IF EXISTS public."{table_name}" CASCADE')
|
||||
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
@@ -44,6 +44,7 @@ class DATestUser(BaseModel):
|
||||
headers: dict
|
||||
role: UserRole
|
||||
is_active: bool
|
||||
cookies: dict = {}
|
||||
|
||||
|
||||
class DATestPersonaLabel(BaseModel):
|
||||
|
@@ -4,7 +4,6 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.tenant import TenantManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
@@ -13,25 +12,28 @@ from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
# Create Tenant 1 and its Admin User
|
||||
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
|
||||
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
|
||||
assert UserManager.is_role(test_user1, UserRole.ADMIN)
|
||||
# Creating an admin user (first user created is automatically an admin and also proviions the tenant
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email="admin@onyx-test.com",
|
||||
)
|
||||
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
|
||||
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
|
||||
assert UserManager.is_role(test_user2, UserRole.ADMIN)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email="admin2@onyx-test.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
# Create connectors for Tenant 1
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=test_user1,
|
||||
user_performing_action=admin_user1,
|
||||
)
|
||||
api_key_1: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=test_user1,
|
||||
user_performing_action=admin_user1,
|
||||
)
|
||||
api_key_1.headers.update(test_user1.headers)
|
||||
LLMProviderManager.create(user_performing_action=test_user1)
|
||||
api_key_1.headers.update(admin_user1.headers)
|
||||
LLMProviderManager.create(user_performing_action=admin_user1)
|
||||
|
||||
# Seed documents for Tenant 1
|
||||
cc_pair_1.documents = []
|
||||
@@ -49,13 +51,13 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
|
||||
# Create connectors for Tenant 2
|
||||
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=test_user2,
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
api_key_2: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=test_user2,
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
api_key_2.headers.update(test_user2.headers)
|
||||
LLMProviderManager.create(user_performing_action=test_user2)
|
||||
api_key_2.headers.update(admin_user2.headers)
|
||||
LLMProviderManager.create(user_performing_action=admin_user2)
|
||||
|
||||
# Seed documents for Tenant 2
|
||||
cc_pair_2.documents = []
|
||||
@@ -76,17 +78,17 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
|
||||
# Create chat sessions for each user
|
||||
chat_session1: DATestChatSession = ChatSessionManager.create(
|
||||
user_performing_action=test_user1
|
||||
user_performing_action=admin_user1
|
||||
)
|
||||
chat_session2: DATestChatSession = ChatSessionManager.create(
|
||||
user_performing_action=test_user2
|
||||
user_performing_action=admin_user2
|
||||
)
|
||||
|
||||
# User 1 sends a message and gets a response
|
||||
response1 = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session1.id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_user1,
|
||||
user_performing_action=admin_user1,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response1.tool_name == "run_search"
|
||||
@@ -100,14 +102,16 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
), "Tenant 2 document IDs should not be in the response"
|
||||
|
||||
# Assert that the contents are correct
|
||||
for doc in response1.tool_result or []:
|
||||
assert doc["content"] == "Tenant 1 Document Content"
|
||||
assert any(
|
||||
doc["content"] == "Tenant 1 Document Content"
|
||||
for doc in response1.tool_result or []
|
||||
), "Tenant 1 Document Content not found in any document"
|
||||
|
||||
# User 2 sends a message and gets a response
|
||||
response2 = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session2.id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_user2,
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response2.tool_name == "run_search"
|
||||
@@ -119,15 +123,18 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
assert not response_doc_ids.intersection(
|
||||
tenant1_doc_ids
|
||||
), "Tenant 1 document IDs should not be in the response"
|
||||
|
||||
# Assert that the contents are correct
|
||||
for doc in response2.tool_result or []:
|
||||
assert doc["content"] == "Tenant 2 Document Content"
|
||||
assert any(
|
||||
doc["content"] == "Tenant 2 Document Content"
|
||||
for doc in response2.tool_result or []
|
||||
), "Tenant 2 Document Content not found in any document"
|
||||
|
||||
# User 1 tries to access Tenant 2's documents
|
||||
response_cross = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session1.id,
|
||||
message="What is in Tenant 2's documents?",
|
||||
user_performing_action=test_user1,
|
||||
user_performing_action=admin_user1,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response_cross.tool_name == "run_search"
|
||||
@@ -140,7 +147,7 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
|
||||
response_cross2 = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session2.id,
|
||||
message="What is in Tenant 1's documents?",
|
||||
user_performing_action=test_user2,
|
||||
user_performing_action=admin_user2,
|
||||
)
|
||||
# Assert that the search tool was used
|
||||
assert response_cross2.tool_name == "run_search"
|
||||
|
@@ -4,14 +4,12 @@ from onyx.db.models import UserRole
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.tenant import TenantManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
# Test flow from creating tenant to registering as a user
|
||||
def test_tenant_creation(reset_multitenant: None) -> None:
|
||||
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
|
||||
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
|
||||
|
||||
assert UserManager.is_role(test_user, UserRole.ADMIN)
|
||||
|
Reference in New Issue
Block a user