add some more multi tenancy

This commit is contained in:
pablodanswer 2024-10-19 16:32:10 -07:00
parent 802dc00f78
commit a2fd8d5e0a
9 changed files with 96 additions and 122 deletions

View File

@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
POSTGRES_API_SERVER_POOL_SIZE = int(

View File

@ -232,6 +232,8 @@ def create_credential(
user: User | None,
db_session: Session,
) -> Credential:
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
credential = Credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
@ -241,7 +243,12 @@ def create_credential(
curator_public=credential_data.curator_public,
)
db_session.add(credential)
db_session.flush() # This ensures the credential gets an ID
# Query and print length of all credentials
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
db_session.flush() # This ensures the credential gets an IDcredentials
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
@ -249,7 +256,7 @@ def create_credential(
user_group_ids=credential_data.groups,
)
db_session.commit()
# db_session.commit()
return credential

View File

@ -291,45 +291,57 @@ async def get_async_session_with_tenant(
yield session
class TenantSession(Session):
def __init__(self, *args, **kwargs):
self.tenant_id = kwargs.pop("tenant_id", None)
super().__init__(*args, **kwargs)
def __enter__(self):
super().__enter__()
if self.tenant_id:
self.execute(text(f'SET search_path TO "{self.tenant_id}"'))
return self
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection without starting a transaction
with engine.connect() as connection:
# Access the raw DBAPI connection
dbapi_connection = connection.connection
engine = get_sqlalchemy_engine()
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session)
# Create a session
with SessionLocal() as session:
# Attach the event listener to set the search_path
@event.listens_for(session, "after_begin")
def _set_search_path(session, transaction, connection, **kw):
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Execute SET search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path TO "{tenant_id}"')
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
yield session
finally:
cursor.close()
# Proceed to create a session using the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
if MULTI_TENANT:
# Reset search_path to default after the session is used
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
session.execute(text('SET search_path TO "$user", public'))
# Optionally, attach engine-level event listener
def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy):
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_connection.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
engine = get_sqlalchemy_engine()
event.listen(engine, "checkout", set_search_path_on_checkout)
def get_session_generator_with_tenant(

View File

@ -1,16 +1,9 @@
import functools
import threading
from typing import Any
from typing import cast
from typing import Optional
from typing import Union
import redis
from redis.client import Redis
from redis.typing import AbsExpiryT
from redis.typing import EncodableT
from redis.typing import ExpiryT
from redis.typing import KeyT
from redis.typing import ResponseT
from danswer.configs.app_configs import REDIS_DB_NUMBER
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
@ -24,15 +17,12 @@ from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
# TODO: enforce typing strictly
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any):
def __init__(self, tenant_id: str, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id
def _prefixed(
self, key: Union[str, bytes, memoryview]
) -> Union[str, bytes, memoryview]:
def _prefixed(self, key):
prefix = f"{self.tenant_id}:"
if isinstance(key, str):
return prefix + key
@ -43,87 +33,32 @@ class TenantRedis(redis.Redis):
else:
raise TypeError(f"Unsupported key type: {type(key)}")
def lock(
self,
name: str,
timeout: Optional[float] = None,
sleep: float = 0.1,
blocking: bool = True,
blocking_timeout: Optional[float] = None,
lock_class: Union[None, Any] = None,
thread_local: bool = True,
) -> Any:
prefixed_name = cast(str, self._prefixed(name))
return super().lock(
prefixed_name,
timeout=timeout,
sleep=sleep,
blocking=blocking,
blocking_timeout=blocking_timeout,
lock_class=lock_class,
thread_local=thread_local,
)
def _prefix_method(self, method):
@functools.wraps(method)
def wrapper(*args, **kwargs):
if "name" in kwargs:
kwargs["name"] = self._prefixed(kwargs["name"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
return method(*args, **kwargs)
def incrby(self, name: KeyT, amount: int = 1) -> ResponseT:
"""
Increments the value of ``key`` by ``amount``. If no key exists,
the value will be initialized as ``amount``
return wrapper
For more information see https://redis.io/commands/incrby
"""
prefixed_name = self._prefixed(name)
return super().incrby(prefixed_name, amount)
def set(
self,
name: KeyT,
value: EncodableT,
ex: Union[ExpiryT, None] = None,
px: Union[ExpiryT, None] = None,
nx: bool = False,
xx: bool = False,
keepttl: bool = False,
get: bool = False,
exat: Union[AbsExpiryT, None] = None,
pxat: Union[AbsExpiryT, None] = None,
) -> ResponseT:
prefixed_name = self._prefixed(name)
return super().set(
prefixed_name,
value,
ex=ex,
px=px,
nx=nx,
xx=xx,
keepttl=keepttl,
get=get,
exat=exat,
pxat=pxat,
)
def get(self, name: KeyT) -> ResponseT:
prefixed_name = self._prefixed(name)
return super().get(prefixed_name)
def delete(self, *names: KeyT) -> ResponseT:
prefixed_names = [self._prefixed(name) for name in names]
return super().delete(*prefixed_names)
def exists(self, *names: KeyT) -> ResponseT:
prefixed_names = [self._prefixed(name) for name in names]
return super().exists(*prefixed_names)
# def expire(self, name: str, time: int, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().expire(prefixed_name, time, **kwargs)
# def ttl(self, name: str, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().ttl(prefixed_name, **kwargs)
# def type(self, name: str, **kwargs: Any) -> Any:
# prefixed_name = self._prefixed(name)
# return super().type(prefixed_name, **kwargs)
def __getattribute__(self, item):
original_attr = super().__getattribute__(item)
methods_to_wrap = [
"get",
"set",
"delete",
"exists",
"incrby",
"hset",
"hget",
"getset",
] # Add all methods that need prefixing
if item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr
class RedisPool:

View File

@ -669,31 +669,45 @@ def create_connector_with_mock_credential(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
print("Starting create_connector_with_mock_credential function")
if user and user.role != UserRole.ADMIN:
print(f"User role: {user.role}")
if connector_data.is_public:
print("Non-admin user attempting to create public credential")
raise HTTPException(
status_code=401,
detail="User does not have permission to create public credentials",
)
if not connector_data.groups:
print("Curator attempting to create connector without specifying groups")
raise HTTPException(
status_code=401,
detail="Curators must specify 1+ groups",
)
try:
print(f"Validating connector: {connector_data.source}")
_validate_connector_allowed(connector_data.source)
print("Creating connector")
connector_response = create_connector(
db_session=db_session, connector_data=connector_data
)
print(f"Connector created with ID: {connector_response.id}")
print("Creating mock credential")
mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source
)
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
print(f"Mock credential created with ID: {credential.id}")
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
print(f"Access type: {access_type}")
print("Adding credential to connector")
response = add_credential_to_connector(
db_session=db_session,
user=user,
@ -703,9 +717,11 @@ def create_connector_with_mock_credential(
cc_pair_name=connector_data.name,
groups=connector_data.groups,
)
print("Credential added to connector successfully")
return response
except ValueError as e:
print(f"ValueError occurred: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))

View File

@ -137,6 +137,10 @@ def create_credential_from_model(
target_group_ids=credential_info.groups,
object_is_public=credential_info.curator_public,
)
from danswer.db.models import Credential
all_credentials = db_session.query(Credential).all()
print(f"Total number of credentials: {len(all_credentials)}")
credential = create_credential(credential_info, user, db_session)
return ObjectCreationIdResponse(

View File

@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data