mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-29 05:15:12 +02:00
add some more multi tenancy
This commit is contained in:
@@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
|||||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||||
)
|
)
|
||||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||||
|
|
||||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||||
|
@@ -232,6 +232,8 @@ def create_credential(
|
|||||||
user: User | None,
|
user: User | None,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
) -> Credential:
|
) -> Credential:
|
||||||
|
all_credentials = db_session.query(Credential).all()
|
||||||
|
print(f"Total number of credentials: {len(all_credentials)}")
|
||||||
credential = Credential(
|
credential = Credential(
|
||||||
credential_json=credential_data.credential_json,
|
credential_json=credential_data.credential_json,
|
||||||
user_id=user.id if user else None,
|
user_id=user.id if user else None,
|
||||||
@@ -241,7 +243,12 @@ def create_credential(
|
|||||||
curator_public=credential_data.curator_public,
|
curator_public=credential_data.curator_public,
|
||||||
)
|
)
|
||||||
db_session.add(credential)
|
db_session.add(credential)
|
||||||
db_session.flush() # This ensures the credential gets an ID
|
# Query and print length of all credentials
|
||||||
|
all_credentials = db_session.query(Credential).all()
|
||||||
|
print(f"Total number of credentials: {len(all_credentials)}")
|
||||||
|
db_session.flush() # This ensures the credential gets an IDcredentials
|
||||||
|
all_credentials = db_session.query(Credential).all()
|
||||||
|
print(f"Total number of credentials: {len(all_credentials)}")
|
||||||
|
|
||||||
_relate_credential_to_user_groups__no_commit(
|
_relate_credential_to_user_groups__no_commit(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@@ -249,7 +256,7 @@ def create_credential(
|
|||||||
user_group_ids=credential_data.groups,
|
user_group_ids=credential_data.groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
db_session.commit()
|
# db_session.commit()
|
||||||
|
|
||||||
return credential
|
return credential
|
||||||
|
|
||||||
|
@@ -291,45 +291,57 @@ async def get_async_session_with_tenant(
|
|||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
class TenantSession(Session):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.tenant_id = kwargs.pop("tenant_id", None)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
super().__enter__()
|
||||||
|
if self.tenant_id:
|
||||||
|
self.execute(text(f'SET search_path TO "{self.tenant_id}"'))
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_session_with_tenant(
|
def get_session_with_tenant(
|
||||||
tenant_id: str | None = None,
|
tenant_id: str | None = None,
|
||||||
) -> Generator[Session, None, None]:
|
) -> Generator[Session, None, None]:
|
||||||
"""Generate a database session with the appropriate tenant schema set."""
|
"""Generate a database session with the appropriate tenant schema set."""
|
||||||
engine = get_sqlalchemy_engine()
|
|
||||||
if tenant_id is None:
|
if tenant_id is None:
|
||||||
tenant_id = current_tenant_id.get()
|
tenant_id = current_tenant_id.get()
|
||||||
|
|
||||||
if not is_valid_schema_name(tenant_id):
|
if not is_valid_schema_name(tenant_id):
|
||||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||||
|
|
||||||
# Establish a raw connection without starting a transaction
|
engine = get_sqlalchemy_engine()
|
||||||
with engine.connect() as connection:
|
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session)
|
||||||
# Access the raw DBAPI connection
|
|
||||||
dbapi_connection = connection.connection
|
# Create a session
|
||||||
|
with SessionLocal() as session:
|
||||||
|
# Attach the event listener to set the search_path
|
||||||
|
@event.listens_for(session, "after_begin")
|
||||||
|
def _set_search_path(session, transaction, connection, **kw):
|
||||||
|
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||||
|
|
||||||
# Execute SET search_path outside of any transaction
|
|
||||||
cursor = dbapi_connection.cursor()
|
|
||||||
try:
|
try:
|
||||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
yield session
|
||||||
# Optionally verify the search_path was set correctly
|
|
||||||
cursor.execute("SHOW search_path")
|
|
||||||
cursor.fetchone()
|
|
||||||
finally:
|
finally:
|
||||||
cursor.close()
|
if MULTI_TENANT:
|
||||||
|
|
||||||
# Proceed to create a session using the connection
|
|
||||||
with Session(bind=connection, expire_on_commit=False) as session:
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
finally:
|
|
||||||
# Reset search_path to default after the session is used
|
# Reset search_path to default after the session is used
|
||||||
if MULTI_TENANT:
|
session.execute(text('SET search_path TO "$user", public'))
|
||||||
cursor = dbapi_connection.cursor()
|
|
||||||
try:
|
|
||||||
cursor.execute('SET search_path TO "$user", public')
|
# Optionally, attach engine-level event listener
|
||||||
finally:
|
def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy):
|
||||||
cursor.close()
|
tenant_id = current_tenant_id.get()
|
||||||
|
if tenant_id and is_valid_schema_name(tenant_id):
|
||||||
|
with dbapi_connection.cursor() as cursor:
|
||||||
|
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||||
|
|
||||||
|
|
||||||
|
engine = get_sqlalchemy_engine()
|
||||||
|
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||||
|
|
||||||
|
|
||||||
def get_session_generator_with_tenant(
|
def get_session_generator_with_tenant(
|
||||||
|
@@ -1,16 +1,9 @@
|
|||||||
|
import functools
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
|
||||||
from typing import cast
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
from redis.client import Redis
|
from redis.client import Redis
|
||||||
from redis.typing import AbsExpiryT
|
|
||||||
from redis.typing import EncodableT
|
|
||||||
from redis.typing import ExpiryT
|
|
||||||
from redis.typing import KeyT
|
|
||||||
from redis.typing import ResponseT
|
|
||||||
|
|
||||||
from danswer.configs.app_configs import REDIS_DB_NUMBER
|
from danswer.configs.app_configs import REDIS_DB_NUMBER
|
||||||
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||||
@@ -24,15 +17,12 @@ from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
|||||||
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||||
|
|
||||||
|
|
||||||
# TODO: enforce typing strictly
|
|
||||||
class TenantRedis(redis.Redis):
|
class TenantRedis(redis.Redis):
|
||||||
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any):
|
def __init__(self, tenant_id: str, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
def _prefixed(
|
def _prefixed(self, key):
|
||||||
self, key: Union[str, bytes, memoryview]
|
|
||||||
) -> Union[str, bytes, memoryview]:
|
|
||||||
prefix = f"{self.tenant_id}:"
|
prefix = f"{self.tenant_id}:"
|
||||||
if isinstance(key, str):
|
if isinstance(key, str):
|
||||||
return prefix + key
|
return prefix + key
|
||||||
@@ -43,87 +33,32 @@ class TenantRedis(redis.Redis):
|
|||||||
else:
|
else:
|
||||||
raise TypeError(f"Unsupported key type: {type(key)}")
|
raise TypeError(f"Unsupported key type: {type(key)}")
|
||||||
|
|
||||||
def lock(
|
def _prefix_method(self, method):
|
||||||
self,
|
@functools.wraps(method)
|
||||||
name: str,
|
def wrapper(*args, **kwargs):
|
||||||
timeout: Optional[float] = None,
|
if "name" in kwargs:
|
||||||
sleep: float = 0.1,
|
kwargs["name"] = self._prefixed(kwargs["name"])
|
||||||
blocking: bool = True,
|
elif len(args) > 0:
|
||||||
blocking_timeout: Optional[float] = None,
|
args = (self._prefixed(args[0]),) + args[1:]
|
||||||
lock_class: Union[None, Any] = None,
|
return method(*args, **kwargs)
|
||||||
thread_local: bool = True,
|
|
||||||
) -> Any:
|
|
||||||
prefixed_name = cast(str, self._prefixed(name))
|
|
||||||
return super().lock(
|
|
||||||
prefixed_name,
|
|
||||||
timeout=timeout,
|
|
||||||
sleep=sleep,
|
|
||||||
blocking=blocking,
|
|
||||||
blocking_timeout=blocking_timeout,
|
|
||||||
lock_class=lock_class,
|
|
||||||
thread_local=thread_local,
|
|
||||||
)
|
|
||||||
|
|
||||||
def incrby(self, name: KeyT, amount: int = 1) -> ResponseT:
|
return wrapper
|
||||||
"""
|
|
||||||
Increments the value of ``key`` by ``amount``. If no key exists,
|
|
||||||
the value will be initialized as ``amount``
|
|
||||||
|
|
||||||
For more information see https://redis.io/commands/incrby
|
def __getattribute__(self, item):
|
||||||
"""
|
original_attr = super().__getattribute__(item)
|
||||||
prefixed_name = self._prefixed(name)
|
methods_to_wrap = [
|
||||||
return super().incrby(prefixed_name, amount)
|
"get",
|
||||||
|
"set",
|
||||||
def set(
|
"delete",
|
||||||
self,
|
"exists",
|
||||||
name: KeyT,
|
"incrby",
|
||||||
value: EncodableT,
|
"hset",
|
||||||
ex: Union[ExpiryT, None] = None,
|
"hget",
|
||||||
px: Union[ExpiryT, None] = None,
|
"getset",
|
||||||
nx: bool = False,
|
] # Add all methods that need prefixing
|
||||||
xx: bool = False,
|
if item in methods_to_wrap and callable(original_attr):
|
||||||
keepttl: bool = False,
|
return self._prefix_method(original_attr)
|
||||||
get: bool = False,
|
return original_attr
|
||||||
exat: Union[AbsExpiryT, None] = None,
|
|
||||||
pxat: Union[AbsExpiryT, None] = None,
|
|
||||||
) -> ResponseT:
|
|
||||||
prefixed_name = self._prefixed(name)
|
|
||||||
return super().set(
|
|
||||||
prefixed_name,
|
|
||||||
value,
|
|
||||||
ex=ex,
|
|
||||||
px=px,
|
|
||||||
nx=nx,
|
|
||||||
xx=xx,
|
|
||||||
keepttl=keepttl,
|
|
||||||
get=get,
|
|
||||||
exat=exat,
|
|
||||||
pxat=pxat,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, name: KeyT) -> ResponseT:
|
|
||||||
prefixed_name = self._prefixed(name)
|
|
||||||
return super().get(prefixed_name)
|
|
||||||
|
|
||||||
def delete(self, *names: KeyT) -> ResponseT:
|
|
||||||
prefixed_names = [self._prefixed(name) for name in names]
|
|
||||||
return super().delete(*prefixed_names)
|
|
||||||
|
|
||||||
def exists(self, *names: KeyT) -> ResponseT:
|
|
||||||
prefixed_names = [self._prefixed(name) for name in names]
|
|
||||||
return super().exists(*prefixed_names)
|
|
||||||
|
|
||||||
# def expire(self, name: str, time: int, **kwargs: Any) -> Any:
|
|
||||||
# prefixed_name = self._prefixed(name)
|
|
||||||
# return super().expire(prefixed_name, time, **kwargs)
|
|
||||||
|
|
||||||
# def ttl(self, name: str, **kwargs: Any) -> Any:
|
|
||||||
# prefixed_name = self._prefixed(name)
|
|
||||||
# return super().ttl(prefixed_name, **kwargs)
|
|
||||||
|
|
||||||
# def type(self, name: str, **kwargs: Any) -> Any:
|
|
||||||
# prefixed_name = self._prefixed(name)
|
|
||||||
# return super().type(prefixed_name, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisPool:
|
class RedisPool:
|
||||||
|
@@ -669,31 +669,45 @@ def create_connector_with_mock_credential(
|
|||||||
user: User = Depends(current_curator_or_admin_user),
|
user: User = Depends(current_curator_or_admin_user),
|
||||||
db_session: Session = Depends(get_session),
|
db_session: Session = Depends(get_session),
|
||||||
) -> StatusResponse:
|
) -> StatusResponse:
|
||||||
|
print("Starting create_connector_with_mock_credential function")
|
||||||
if user and user.role != UserRole.ADMIN:
|
if user and user.role != UserRole.ADMIN:
|
||||||
|
print(f"User role: {user.role}")
|
||||||
if connector_data.is_public:
|
if connector_data.is_public:
|
||||||
|
print("Non-admin user attempting to create public credential")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
detail="User does not have permission to create public credentials",
|
detail="User does not have permission to create public credentials",
|
||||||
)
|
)
|
||||||
if not connector_data.groups:
|
if not connector_data.groups:
|
||||||
|
print("Curator attempting to create connector without specifying groups")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
detail="Curators must specify 1+ groups",
|
detail="Curators must specify 1+ groups",
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
print(f"Validating connector: {connector_data.source}")
|
||||||
_validate_connector_allowed(connector_data.source)
|
_validate_connector_allowed(connector_data.source)
|
||||||
|
print("Creating connector")
|
||||||
connector_response = create_connector(
|
connector_response = create_connector(
|
||||||
db_session=db_session, connector_data=connector_data
|
db_session=db_session, connector_data=connector_data
|
||||||
)
|
)
|
||||||
|
print(f"Connector created with ID: {connector_response.id}")
|
||||||
|
|
||||||
|
print("Creating mock credential")
|
||||||
mock_credential = CredentialBase(
|
mock_credential = CredentialBase(
|
||||||
credential_json={}, admin_public=True, source=connector_data.source
|
credential_json={}, admin_public=True, source=connector_data.source
|
||||||
)
|
)
|
||||||
credential = create_credential(
|
credential = create_credential(
|
||||||
mock_credential, user=user, db_session=db_session
|
mock_credential, user=user, db_session=db_session
|
||||||
)
|
)
|
||||||
|
print(f"Mock credential created with ID: {credential.id}")
|
||||||
|
|
||||||
access_type = (
|
access_type = (
|
||||||
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
||||||
)
|
)
|
||||||
|
print(f"Access type: {access_type}")
|
||||||
|
|
||||||
|
print("Adding credential to connector")
|
||||||
response = add_credential_to_connector(
|
response = add_credential_to_connector(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
user=user,
|
user=user,
|
||||||
@@ -703,9 +717,11 @@ def create_connector_with_mock_credential(
|
|||||||
cc_pair_name=connector_data.name,
|
cc_pair_name=connector_data.name,
|
||||||
groups=connector_data.groups,
|
groups=connector_data.groups,
|
||||||
)
|
)
|
||||||
|
print("Credential added to connector successfully")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
print(f"ValueError occurred: {str(e)}")
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@@ -137,6 +137,10 @@ def create_credential_from_model(
|
|||||||
target_group_ids=credential_info.groups,
|
target_group_ids=credential_info.groups,
|
||||||
object_is_public=credential_info.curator_public,
|
object_is_public=credential_info.curator_public,
|
||||||
)
|
)
|
||||||
|
from danswer.db.models import Credential
|
||||||
|
|
||||||
|
all_credentials = db_session.query(Credential).all()
|
||||||
|
print(f"Total number of credentials: {len(all_credentials)}")
|
||||||
|
|
||||||
credential = create_credential(credential_info, user, db_session)
|
credential = create_credential(credential_info, user, db_session)
|
||||||
return ObjectCreationIdResponse(
|
return ObjectCreationIdResponse(
|
||||||
|
@@ -313,7 +313,7 @@ services:
|
|||||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5433:5432"
|
||||||
volumes:
|
volumes:
|
||||||
- db_volume:/var/lib/postgresql/data
|
- db_volume:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
@@ -313,7 +313,7 @@ services:
|
|||||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||||
ports:
|
ports:
|
||||||
- "5432:5432"
|
- "5433:5432"
|
||||||
volumes:
|
volumes:
|
||||||
- db_volume:/var/lib/postgresql/data
|
- db_volume:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
@@ -157,7 +157,7 @@ services:
|
|||||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||||
ports:
|
ports:
|
||||||
- "5432"
|
- "5433"
|
||||||
volumes:
|
volumes:
|
||||||
- db_volume:/var/lib/postgresql/data
|
- db_volume:/var/lib/postgresql/data
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user