mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-11 13:32:16 +02:00
add some more multi tenancy
This commit is contained in:
parent
802dc00f78
commit
a2fd8d5e0a
@ -140,7 +140,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
|
@ -232,6 +232,8 @@ def create_credential(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> Credential:
|
||||
all_credentials = db_session.query(Credential).all()
|
||||
print(f"Total number of credentials: {len(all_credentials)}")
|
||||
credential = Credential(
|
||||
credential_json=credential_data.credential_json,
|
||||
user_id=user.id if user else None,
|
||||
@ -241,7 +243,12 @@ def create_credential(
|
||||
curator_public=credential_data.curator_public,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush() # This ensures the credential gets an ID
|
||||
# Query and print length of all credentials
|
||||
all_credentials = db_session.query(Credential).all()
|
||||
print(f"Total number of credentials: {len(all_credentials)}")
|
||||
db_session.flush() # This ensures the credential gets an IDcredentials
|
||||
all_credentials = db_session.query(Credential).all()
|
||||
print(f"Total number of credentials: {len(all_credentials)}")
|
||||
|
||||
_relate_credential_to_user_groups__no_commit(
|
||||
db_session=db_session,
|
||||
@ -249,7 +256,7 @@ def create_credential(
|
||||
user_group_ids=credential_data.groups,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
# db_session.commit()
|
||||
|
||||
return credential
|
||||
|
||||
|
@ -291,45 +291,57 @@ async def get_async_session_with_tenant(
|
||||
yield session
|
||||
|
||||
|
||||
class TenantSession(Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.tenant_id = kwargs.pop("tenant_id", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __enter__(self):
|
||||
super().__enter__()
|
||||
if self.tenant_id:
|
||||
self.execute(text(f'SET search_path TO "{self.tenant_id}"'))
|
||||
return self
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
engine = get_sqlalchemy_engine()
|
||||
if tenant_id is None:
|
||||
tenant_id = current_tenant_id.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
# Establish a raw connection without starting a transaction
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection
|
||||
dbapi_connection = connection.connection
|
||||
engine = get_sqlalchemy_engine()
|
||||
SessionLocal = sessionmaker(bind=engine, expire_on_commit=False, class_=Session)
|
||||
|
||||
# Create a session
|
||||
with SessionLocal() as session:
|
||||
# Attach the event listener to set the search_path
|
||||
@event.listens_for(session, "after_begin")
|
||||
def _set_search_path(session, transaction, connection, **kw):
|
||||
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
|
||||
# Execute SET search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
# Optionally verify the search_path was set correctly
|
||||
cursor.execute("SHOW search_path")
|
||||
cursor.fetchone()
|
||||
yield session
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Proceed to create a session using the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
if MULTI_TENANT:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
session.execute(text('SET search_path TO "$user", public'))
|
||||
|
||||
|
||||
# Optionally, attach engine-level event listener
|
||||
def set_search_path_on_checkout(dbapi_connection, connection_record, connection_proxy):
|
||||
tenant_id = current_tenant_id.get()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_connection.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
|
||||
def get_session_generator_with_tenant(
|
||||
|
@ -1,16 +1,9 @@
|
||||
import functools
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import redis
|
||||
from redis.client import Redis
|
||||
from redis.typing import AbsExpiryT
|
||||
from redis.typing import EncodableT
|
||||
from redis.typing import ExpiryT
|
||||
from redis.typing import KeyT
|
||||
from redis.typing import ResponseT
|
||||
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER
|
||||
from danswer.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
@ -24,15 +17,12 @@ from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
|
||||
|
||||
# TODO: enforce typing strictly
|
||||
class TenantRedis(redis.Redis):
|
||||
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any):
|
||||
def __init__(self, tenant_id: str, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def _prefixed(
|
||||
self, key: Union[str, bytes, memoryview]
|
||||
) -> Union[str, bytes, memoryview]:
|
||||
def _prefixed(self, key):
|
||||
prefix = f"{self.tenant_id}:"
|
||||
if isinstance(key, str):
|
||||
return prefix + key
|
||||
@ -43,87 +33,32 @@ class TenantRedis(redis.Redis):
|
||||
else:
|
||||
raise TypeError(f"Unsupported key type: {type(key)}")
|
||||
|
||||
def lock(
|
||||
self,
|
||||
name: str,
|
||||
timeout: Optional[float] = None,
|
||||
sleep: float = 0.1,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: Optional[float] = None,
|
||||
lock_class: Union[None, Any] = None,
|
||||
thread_local: bool = True,
|
||||
) -> Any:
|
||||
prefixed_name = cast(str, self._prefixed(name))
|
||||
return super().lock(
|
||||
prefixed_name,
|
||||
timeout=timeout,
|
||||
sleep=sleep,
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
lock_class=lock_class,
|
||||
thread_local=thread_local,
|
||||
)
|
||||
def _prefix_method(self, method):
|
||||
@functools.wraps(method)
|
||||
def wrapper(*args, **kwargs):
|
||||
if "name" in kwargs:
|
||||
kwargs["name"] = self._prefixed(kwargs["name"])
|
||||
elif len(args) > 0:
|
||||
args = (self._prefixed(args[0]),) + args[1:]
|
||||
return method(*args, **kwargs)
|
||||
|
||||
def incrby(self, name: KeyT, amount: int = 1) -> ResponseT:
|
||||
"""
|
||||
Increments the value of ``key`` by ``amount``. If no key exists,
|
||||
the value will be initialized as ``amount``
|
||||
return wrapper
|
||||
|
||||
For more information see https://redis.io/commands/incrby
|
||||
"""
|
||||
prefixed_name = self._prefixed(name)
|
||||
return super().incrby(prefixed_name, amount)
|
||||
|
||||
def set(
|
||||
self,
|
||||
name: KeyT,
|
||||
value: EncodableT,
|
||||
ex: Union[ExpiryT, None] = None,
|
||||
px: Union[ExpiryT, None] = None,
|
||||
nx: bool = False,
|
||||
xx: bool = False,
|
||||
keepttl: bool = False,
|
||||
get: bool = False,
|
||||
exat: Union[AbsExpiryT, None] = None,
|
||||
pxat: Union[AbsExpiryT, None] = None,
|
||||
) -> ResponseT:
|
||||
prefixed_name = self._prefixed(name)
|
||||
return super().set(
|
||||
prefixed_name,
|
||||
value,
|
||||
ex=ex,
|
||||
px=px,
|
||||
nx=nx,
|
||||
xx=xx,
|
||||
keepttl=keepttl,
|
||||
get=get,
|
||||
exat=exat,
|
||||
pxat=pxat,
|
||||
)
|
||||
|
||||
def get(self, name: KeyT) -> ResponseT:
|
||||
prefixed_name = self._prefixed(name)
|
||||
return super().get(prefixed_name)
|
||||
|
||||
def delete(self, *names: KeyT) -> ResponseT:
|
||||
prefixed_names = [self._prefixed(name) for name in names]
|
||||
return super().delete(*prefixed_names)
|
||||
|
||||
def exists(self, *names: KeyT) -> ResponseT:
|
||||
prefixed_names = [self._prefixed(name) for name in names]
|
||||
return super().exists(*prefixed_names)
|
||||
|
||||
# def expire(self, name: str, time: int, **kwargs: Any) -> Any:
|
||||
# prefixed_name = self._prefixed(name)
|
||||
# return super().expire(prefixed_name, time, **kwargs)
|
||||
|
||||
# def ttl(self, name: str, **kwargs: Any) -> Any:
|
||||
# prefixed_name = self._prefixed(name)
|
||||
# return super().ttl(prefixed_name, **kwargs)
|
||||
|
||||
# def type(self, name: str, **kwargs: Any) -> Any:
|
||||
# prefixed_name = self._prefixed(name)
|
||||
# return super().type(prefixed_name, **kwargs)
|
||||
def __getattribute__(self, item):
|
||||
original_attr = super().__getattribute__(item)
|
||||
methods_to_wrap = [
|
||||
"get",
|
||||
"set",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
"hset",
|
||||
"hget",
|
||||
"getset",
|
||||
] # Add all methods that need prefixing
|
||||
if item in methods_to_wrap and callable(original_attr):
|
||||
return self._prefix_method(original_attr)
|
||||
return original_attr
|
||||
|
||||
|
||||
class RedisPool:
|
||||
|
@ -669,31 +669,45 @@ def create_connector_with_mock_credential(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
print("Starting create_connector_with_mock_credential function")
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
print(f"User role: {user.role}")
|
||||
if connector_data.is_public:
|
||||
print("Non-admin user attempting to create public credential")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="User does not have permission to create public credentials",
|
||||
)
|
||||
if not connector_data.groups:
|
||||
print("Curator attempting to create connector without specifying groups")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Curators must specify 1+ groups",
|
||||
)
|
||||
try:
|
||||
print(f"Validating connector: {connector_data.source}")
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
print("Creating connector")
|
||||
connector_response = create_connector(
|
||||
db_session=db_session, connector_data=connector_data
|
||||
)
|
||||
print(f"Connector created with ID: {connector_response.id}")
|
||||
|
||||
print("Creating mock credential")
|
||||
mock_credential = CredentialBase(
|
||||
credential_json={}, admin_public=True, source=connector_data.source
|
||||
)
|
||||
credential = create_credential(
|
||||
mock_credential, user=user, db_session=db_session
|
||||
)
|
||||
print(f"Mock credential created with ID: {credential.id}")
|
||||
|
||||
access_type = (
|
||||
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
||||
)
|
||||
print(f"Access type: {access_type}")
|
||||
|
||||
print("Adding credential to connector")
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
@ -703,9 +717,11 @@ def create_connector_with_mock_credential(
|
||||
cc_pair_name=connector_data.name,
|
||||
groups=connector_data.groups,
|
||||
)
|
||||
print("Credential added to connector successfully")
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
print(f"ValueError occurred: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
|
@ -137,6 +137,10 @@ def create_credential_from_model(
|
||||
target_group_ids=credential_info.groups,
|
||||
object_is_public=credential_info.curator_public,
|
||||
)
|
||||
from danswer.db.models import Credential
|
||||
|
||||
all_credentials = db_session.query(Credential).all()
|
||||
print(f"Total number of credentials: {len(all_credentials)}")
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
return ObjectCreationIdResponse(
|
||||
|
@ -313,7 +313,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
@ -313,7 +313,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
@ -157,7 +157,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432"
|
||||
- "5433"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user