Multi tenant tests (#3919)

* ensure fail on multi tenant successfully

* attempted fix

* udpate ingration tests

* minor update

* improve

* improve workflow

* fix migrations

* many more logs

* quick fix

* improve

* fix typo

* quick nit

* attempted fix

* very minor clean up
This commit is contained in:
pablonyx
2025-02-06 17:24:00 -08:00
committed by GitHub
parent bfa4fbd691
commit 48ac690a70
13 changed files with 509 additions and 90 deletions

View File

@@ -5,7 +5,6 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
@@ -20,13 +19,8 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
@@ -63,7 +57,6 @@ def upgrade() -> None:
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
@@ -71,15 +64,12 @@ def upgrade() -> None:
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
@@ -170,10 +160,9 @@ def upgrade() -> None:
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
pass
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
@@ -190,8 +179,6 @@ def upgrade() -> None:
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
@@ -273,7 +260,7 @@ def downgrade() -> None:
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
pass
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")

View File

@@ -64,6 +64,7 @@ async def _get_tenant_id_from_request(
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:

View File

@@ -24,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
@@ -85,7 +86,8 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email, referral_source)
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)

View File

@@ -34,6 +34,7 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -286,7 +287,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)
if MULTI_TENANT:
if MULTI_TENANT and not DEV_MODE:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None

View File

@@ -70,8 +70,8 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./onyx /app/onyx
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf

View File

@@ -24,35 +24,6 @@ def generate_auth_token() -> str:
class TenantManager:
@staticmethod
def create(
tenant_id: str | None = None,
initial_admin_email: str | None = None,
referral_source: str | None = None,
) -> dict[str, str]:
body = {
"tenant_id": tenant_id,
"initial_admin_email": initial_admin_email,
"referral_source": referral_source,
}
token = generate_auth_token()
headers = {
"Authorization": f"Bearer {token}",
"X-API-KEY": "",
"Content-Type": "application/json",
}
response = requests.post(
url=f"{API_SERVER_URL}/tenants/create",
json=body,
headers=headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all_users(
user_performing_action: DATestUser | None = None,

View File

@@ -92,6 +92,7 @@ class UserManager:
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}
return test_user
@staticmethod
@@ -102,6 +103,7 @@ class UserManager:
response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=user_to_verify.headers,
cookies=user_to_verify.cookies,
)
if user_to_verify.is_active is False:

View File

@@ -242,6 +242,18 @@ def reset_postgres_multitenant() -> None:
schema_name = schema[0]
cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')
# Drop tables in the public schema
cur.execute(
"""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
"""
)
public_tables = cur.fetchall()
for table in public_tables:
table_name = table[0]
cur.execute(f'DROP TABLE IF EXISTS public."{table_name}" CASCADE')
cur.close()
conn.close()

View File

@@ -44,6 +44,7 @@ class DATestUser(BaseModel):
headers: dict
role: UserRole
is_active: bool
cookies: dict = {}
class DATestPersonaLabel(BaseModel):

View File

@@ -4,7 +4,6 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
@@ -13,25 +12,28 @@ from tests.integration.common_utils.test_models import DATestUser
def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create Tenant 1 and its Admin User
TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
assert UserManager.is_role(test_user1, UserRole.ADMIN)
# Creating an admin user (first user created is automatically an admin and also proviions the tenant
admin_user1: DATestUser = UserManager.create(
email="admin@onyx-test.com",
)
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
# Create Tenant 2 and its Admin User
TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
assert UserManager.is_role(test_user2, UserRole.ADMIN)
admin_user2: DATestUser = UserManager.create(
email="admin2@onyx-test.com",
)
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
# Create connectors for Tenant 1
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=test_user1,
user_performing_action=admin_user1,
)
api_key_1: DATestAPIKey = APIKeyManager.create(
user_performing_action=test_user1,
user_performing_action=admin_user1,
)
api_key_1.headers.update(test_user1.headers)
LLMProviderManager.create(user_performing_action=test_user1)
api_key_1.headers.update(admin_user1.headers)
LLMProviderManager.create(user_performing_action=admin_user1)
# Seed documents for Tenant 1
cc_pair_1.documents = []
@@ -49,13 +51,13 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create connectors for Tenant 2
cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=test_user2,
user_performing_action=admin_user2,
)
api_key_2: DATestAPIKey = APIKeyManager.create(
user_performing_action=test_user2,
user_performing_action=admin_user2,
)
api_key_2.headers.update(test_user2.headers)
LLMProviderManager.create(user_performing_action=test_user2)
api_key_2.headers.update(admin_user2.headers)
LLMProviderManager.create(user_performing_action=admin_user2)
# Seed documents for Tenant 2
cc_pair_2.documents = []
@@ -76,17 +78,17 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
# Create chat sessions for each user
chat_session1: DATestChatSession = ChatSessionManager.create(
user_performing_action=test_user1
user_performing_action=admin_user1
)
chat_session2: DATestChatSession = ChatSessionManager.create(
user_performing_action=test_user2
user_performing_action=admin_user2
)
# User 1 sends a message and gets a response
response1 = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
message="What is in Tenant 1's documents?",
user_performing_action=test_user1,
user_performing_action=admin_user1,
)
# Assert that the search tool was used
assert response1.tool_name == "run_search"
@@ -100,14 +102,16 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
), "Tenant 2 document IDs should not be in the response"
# Assert that the contents are correct
for doc in response1.tool_result or []:
assert doc["content"] == "Tenant 1 Document Content"
assert any(
doc["content"] == "Tenant 1 Document Content"
for doc in response1.tool_result or []
), "Tenant 1 Document Content not found in any document"
# User 2 sends a message and gets a response
response2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
message="What is in Tenant 2's documents?",
user_performing_action=test_user2,
user_performing_action=admin_user2,
)
# Assert that the search tool was used
assert response2.tool_name == "run_search"
@@ -119,15 +123,18 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
assert not response_doc_ids.intersection(
tenant1_doc_ids
), "Tenant 1 document IDs should not be in the response"
# Assert that the contents are correct
for doc in response2.tool_result or []:
assert doc["content"] == "Tenant 2 Document Content"
assert any(
doc["content"] == "Tenant 2 Document Content"
for doc in response2.tool_result or []
), "Tenant 2 Document Content not found in any document"
# User 1 tries to access Tenant 2's documents
response_cross = ChatSessionManager.send_message(
chat_session_id=chat_session1.id,
message="What is in Tenant 2's documents?",
user_performing_action=test_user1,
user_performing_action=admin_user1,
)
# Assert that the search tool was used
assert response_cross.tool_name == "run_search"
@@ -140,7 +147,7 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
response_cross2 = ChatSessionManager.send_message(
chat_session_id=chat_session2.id,
message="What is in Tenant 1's documents?",
user_performing_action=test_user2,
user_performing_action=admin_user2,
)
# Assert that the search tool was used
assert response_cross2.tool_name == "run_search"

View File

@@ -4,14 +4,12 @@ from onyx.db.models import UserRole
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
test_user: DATestUser = UserManager.create(name="test", email="test@test.com")
assert UserManager.is_role(test_user, UserRole.ADMIN)