add some more multi tenancy

This commit is contained in:
pablodanswer
2024-10-19 16:32:10 -07:00
parent 802dc00f78
commit a2fd8d5e0a
9 changed files with 96 additions and 122 deletions

View File

@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password" os.environ.get("POSTGRES_PASSWORD") or "password"
) )
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int( POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@@ -232,6 +232,8 @@ def create_credential(
user: User | None, user: User | None,
db_session: Session, db_session: Session,
) -> Credential: ) -> Credential:
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
credential = Credential( credential = Credential(
credential_json=credential_data.credential_json, credential_json=credential_data.credential_json,
user_id=user.id if user else None, user_id=user.id if user else None,
@@ -241,7 +243,12 @@ def create_credential(
curator_public=credential_data.curator_public, curator_public=credential_data.curator_public,
) )
db_session.add(credential) db_session.add(credential)
db_session.flush() # This ensures the credential gets an ID # Query and print length of all credentials
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
db_session.flush() # This ensures the credential gets an IDcredentials
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
_relate_credential_to_user_groups__no_commit( _relate_credential_to_user_groups__no_commit(
db_session=db_session, db_session=db_session,
@@ -249,7 +256,7 @@ def create_credential(
user_group_ids=credential_data.groups, user_group_ids=credential_data.groups,
) )
db_session.commit() # db_session.commit()
return credential return credential

View File

@@ -291,45 +291,57 @@ async def get_async_session_with_tenant(
yield session yield session
class TenantSession(Session):
    """Session that selects a tenant schema when entered as a context manager.

    Takes the same positional and keyword arguments as ``Session``, plus an
    optional ``tenant_id`` keyword naming the schema to place on the
    connection's ``search_path``.
    """

    def __init__(self, *args, **kwargs):
        """Strip ``tenant_id`` from ``kwargs`` and forward the rest to ``Session``.

        ``tenant_id`` defaults to ``None``; in that case entering the session
        leaves ``search_path`` untouched.
        """
        # Popped before delegating so the parent constructor never receives
        # the extra keyword.
        self.tenant_id = kwargs.pop("tenant_id", None)
        super().__init__(*args, **kwargs)

    def __enter__(self):
        """Enter the session and, if a tenant is set, issue ``SET search_path``.

        Returns:
            This session instance.
        """
        super().__enter__()
        if self.tenant_id:
            # The schema name is spliced into the statement as a quoted
            # identifier, so embedded double quotes are doubled to keep the
            # value from terminating the identifier early and injecting SQL.
            schema = str(self.tenant_id).replace('"', '""')
            self.execute(text(f'SET search_path TO "{schema}"'))
        return self
@contextmanager @contextmanager
def get_session_with_tenant( def get_session_with_tenant(
tenant_id: str | None = None, tenant_id: str | None = None,
) -> Generator[Session, None, None]: ) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set.""" """Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None: if tenant_id is None:
tenant_id = current_tenant_id.get() tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id): if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID") raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection without starting a transaction engine = get_sqlalchemy_engine()
with engine.connect() as connection: SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session)
# Access the raw DBAPI connection
dbapi_connection = connection.connection # Create a session
with SessionLocal() as session:
# Attach the event listener to set the search_path
@event.listens_for(session, "after_begin")
def _set_search_path(session, transaction, connection, **kw):
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Execute SET search_path outside of any transaction
cursor = dbapi_connection.cursor()
try: try:
cursor.execute(f'SET search_path TO "{tenant_id}"') yield session
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
finally: finally:
cursor.close() if MULTI_TENANT:
# Proceed to create a session using the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
# Reset search_path to default after the session is used # Reset search_path to default after the session is used
if MULTI_TENANT: session.execute(text('SET search_path TO "$user", public'))
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public') # Optionally, attach engine-level event listener
finally: def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy):
cursor.close() tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_connection.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", set_search_path_on_checkout)
def get_session_generator_with_tenant( def get_session_generator_with_tenant(

View File

@@ -1,16 +1,9 @@
import functools
import threading import threading
from typing import Any
from typing import cast
from typing import Optional from typing import Optional
from typing import Union
import redis import redis
from redis.client import Redis from redis.client import Redis
from redis.typing import AbsExpiryT
from redis.typing import EncodableT
from redis.typing import ExpiryT
from redis.typing import KeyT
from redis.typing import ResponseT
from danswer.configs.app_configs import REDIS_DB_NUMBER from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
@@ -24,15 +17,12 @@ from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
# TODO: enforce typing strictly
class TenantRedis(redis.Redis): class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any): def __init__(self, tenant_id: str, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.tenant_id = tenant_id self.tenant_id = tenant_id
def _prefixed( def _prefixed(self, key):
self, key: Union[str, bytes, memoryview]
) -> Union[str, bytes, memoryview]:
prefix = f"{self.tenant_id}:" prefix = f"{self.tenant_id}:"
if isinstance(key, str): if isinstance(key, str):
return prefix + key return prefix + key
@@ -43,87 +33,32 @@ class TenantRedis(redis.Redis):
else: else:
raise TypeError(f"Unsupported key type: {type(key)}") raise TypeError(f"Unsupported key type: {type(key)}")
def lock( def _prefix_method(self, method):
self, @functools.wraps(method)
name: str, def wrapper(*args, **kwargs):
timeout: Optional[float] = None, if "name" in kwargs:
sleep: float = 0.1, kwargs["name"] = self._prefixed(kwargs["name"])
blocking: bool = True, elif len(args) > 0:
blocking_timeout: Optional[float] = None, args = (self._prefixed(args[0]),) + args[1:]
lock_class: Union[None, Any] = None, return method(*args, **kwargs)
thread_local: bool = True,
) -> Any:
prefixed_name = cast(str, self._prefixed(name))
return super().lock(
prefixed_name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
lock_class=lock_class,
thread_local=thread_local,
)
def incrby(self, name: KeyT, amount: int = 1) -> ResponseT: return wrapper
"""
Increments the value of ``key`` by ``amount``. If no key exists,
the value will be initialized as ``amount``
For more information see https://redis.io/commands/incrby def __getattribute__(self, item):
""" original_attr = super().__getattribute__(item)
prefixed_name = self._prefixed(name) methods_to_wrap = [
return super().incrby(prefixed_name, amount) "get",
"set",
def set( "delete",
self, "exists",
name: KeyT, "incrby",
value: EncodableT, "hset",
ex: Union[ExpiryT, None] = None, "hget",
px: Union[ExpiryT, None] = None, "getset",
nx: bool = False, ] # Add all methods that need prefixing
xx: bool = False, if item in methods_to_wrap and callable(original_attr):
keepttl: bool = False, return self._prefix_method(original_attr)
get: bool = False, return original_attr
exat: Union[AbsExpiryT, None] = None,
pxat: Union[AbsExpiryT, None] = None,
) -> ResponseT:
prefixed_name = self._prefixed(name)
return super().set(
prefixed_name,
value,
ex=ex,
px=px,
nx=nx,
xx=xx,
keepttl=keepttl,
get=get,
exat=exat,
pxat=pxat,
)
def get(self, name: KeyT) -> ResponseT:
prefixed_name = self._prefixed(name)
return super().get(prefixed_name)
def delete(self, *names: KeyT) -> ResponseT:
prefixed_names = [self._prefixed(name) for name in names]
return super().delete(*prefixed_names)
def exists(self, *names: KeyT) -> ResponseT:
prefixed_names = [self._prefixed(name) for name in names]
return super().exists(*prefixed_names)
# def expire(self, name: str, time: int, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().expire(prefixed_name, time, **kwargs)
# def ttl(self, name: str, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().ttl(prefixed_name, **kwargs)
# def type(self, name: str, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().type(prefixed_name, **kwargs)
class RedisPool: class RedisPool:

View File

@@ -669,31 +669,45 @@ def create_connector_with_mock_credential(
user: User = Depends(current_curator_or_admin_user), user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> StatusResponse: ) -> StatusResponse:
print("Starting create_connector_with_mock_credential function")
if user and user.role != UserRole.ADMIN: if user and user.role != UserRole.ADMIN:
print(f"User role: {user.role}")
if connector_data.is_public: if connector_data.is_public:
print("Non-admin user attempting to create public credential")
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="User does not have permission to create public credentials", detail="User does not have permission to create public credentials",
) )
if not connector_data.groups: if not connector_data.groups:
print("Curator attempting to create connector without specifying groups")
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
detail="Curators must specify 1+ groups", detail="Curators must specify 1+ groups",
) )
try: try:
print(f"Validating connector: {connector_data.source}")
_validate_connector_allowed(connector_data.source) _validate_connector_allowed(connector_data.source)
print("Creating connector")
connector_response = create_connector( connector_response = create_connector(
db_session=db_session, connector_data=connector_data db_session=db_session, connector_data=connector_data
) )
print(f"Connector created with ID: {connector_response.id}")
print("Creating mock credential")
mock_credential = CredentialBase( mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source credential_json={}, admin_public=True, source=connector_data.source
) )
credential = create_credential( credential = create_credential(
mock_credential, user=user, db_session=db_session mock_credential, user=user, db_session=db_session
) )
print(f"Mock credential created with ID: {credential.id}")
access_type = ( access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
) )
print(f"Access type: {access_type}")
print("Adding credential to connector")
response = add_credential_to_connector( response = add_credential_to_connector(
db_session=db_session, db_session=db_session,
user=user, user=user,
@@ -703,9 +717,11 @@ def create_connector_with_mock_credential(
cc_pair_name=connector_data.name, cc_pair_name=connector_data.name,
groups=connector_data.groups, groups=connector_data.groups,
) )
print("Credential added to connector successfully")
return response return response
except ValueError as e: except ValueError as e:
print(f"ValueError occurred: {str(e)}")
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))

View File

@@ -137,6 +137,10 @@ def create_credential_from_model(
target_group_ids=credential_info.groups, target_group_ids=credential_info.groups,
object_is_public=credential_info.curator_public, object_is_public=credential_info.curator_public,
) )
from danswer.db.models import Credential
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
credential = create_credential(credential_info, user, db_session) credential = create_credential(credential_info, user, db_session)
return ObjectCreationIdResponse( return ObjectCreationIdResponse(

View File

@@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5432:5432" - "5433:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5432:5432" - "5433:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data

View File

@@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password} - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports: ports:
- "5432" - "5433"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data