mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 13:53:28 +02:00
Private personas doc sets (#52)
Private Personas and Document Sets --------- Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
123
backend/ee/danswer/db/document_set.py
Normal file
123
backend/ee/danswer/db/document_set.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import DocumentSet__ConnectorCredentialPair
|
||||
from danswer.db.models import DocumentSet__User
|
||||
from danswer.db.models import DocumentSet__UserGroup
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
|
||||
|
||||
def make_doc_set_private(
|
||||
document_set_id: int,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.document_set_id == document_set_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(DocumentSet__UserGroup).filter(
|
||||
DocumentSet__UserGroup.document_set_id == document_set_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
for user_uuid in user_ids:
|
||||
db_session.add(
|
||||
DocumentSet__User(document_set_id=document_set_id, user_id=user_uuid)
|
||||
)
|
||||
|
||||
if group_ids:
|
||||
for group_id in group_ids:
|
||||
db_session.add(
|
||||
DocumentSet__UserGroup(
|
||||
document_set_id=document_set_id, user_group_id=group_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def delete_document_set_privacy__no_commit(
|
||||
document_set_id: int, db_session: Session
|
||||
) -> None:
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.document_set_id == document_set_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
db_session.query(DocumentSet__UserGroup).filter(
|
||||
DocumentSet__UserGroup.document_set_id == document_set_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
|
||||
def fetch_document_sets(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
include_outdated: bool = True, # Parameter only for versioned implementation, unused
|
||||
) -> list[tuple[DocumentSet, list[ConnectorCredentialPair]]]:
|
||||
assert user_id is not None
|
||||
|
||||
# Public document sets
|
||||
public_document_sets = (
|
||||
db_session.query(DocumentSet)
|
||||
.filter(DocumentSet.is_public == True) # noqa
|
||||
.all()
|
||||
)
|
||||
|
||||
# Document sets via shared user relationships
|
||||
shared_document_sets = (
|
||||
db_session.query(DocumentSet)
|
||||
.join(DocumentSet__User, DocumentSet.id == DocumentSet__User.document_set_id)
|
||||
.filter(DocumentSet__User.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Document sets via groups
|
||||
# First, find the user groups the user belongs to
|
||||
user_groups = (
|
||||
db_session.query(UserGroup)
|
||||
.join(User__UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.filter(User__UserGroup.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
group_document_sets = []
|
||||
for group in user_groups:
|
||||
group_document_sets.extend(
|
||||
db_session.query(DocumentSet)
|
||||
.join(
|
||||
DocumentSet__UserGroup,
|
||||
DocumentSet.id == DocumentSet__UserGroup.document_set_id,
|
||||
)
|
||||
.filter(DocumentSet__UserGroup.user_group_id == group.id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Combine and deduplicate document sets from all sources
|
||||
all_document_sets = list(
|
||||
set(public_document_sets + shared_document_sets + group_document_sets)
|
||||
)
|
||||
|
||||
document_set_with_cc_pairs: list[
|
||||
tuple[DocumentSet, list[ConnectorCredentialPair]]
|
||||
] = []
|
||||
|
||||
for document_set in all_document_sets:
|
||||
# Fetch the associated ConnectorCredentialPairs
|
||||
cc_pairs = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.join(
|
||||
DocumentSet__ConnectorCredentialPair,
|
||||
ConnectorCredentialPair.id
|
||||
== DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
|
||||
)
|
||||
.filter(
|
||||
DocumentSet__ConnectorCredentialPair.document_set_id == document_set.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
|
||||
|
||||
return document_set_with_cc_pairs
|
@@ -1,96 +0,0 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlalchemy.orm import mapped_column
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from danswer.db.models import Base
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
class SamlAccount(Base):
|
||||
__tablename__ = "saml"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
|
||||
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
|
||||
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User")
|
||||
|
||||
|
||||
"""Tables related to RBAC"""
|
||||
|
||||
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
cc_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), primary_key=True
|
||||
)
|
||||
# if `True`, then is part of the current state of the UserGroup
|
||||
# if `False`, then is a part of the prior state of the UserGroup
|
||||
# rows with `is_current=False` should be deleted when the UserGroup
|
||||
# is updated and should not exist for a given UserGroup if
|
||||
# `UserGroup.is_up_to_date == True`
|
||||
is_current: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
cc_pair: Mapped[ConnectorCredentialPair] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
)
|
||||
|
||||
|
||||
class UserGroup(Base):
|
||||
__tablename__ = "user_group"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
# whether or not changes to the UserGroup have been propogated to Vespa
|
||||
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
# tell the sync job to clean up the group
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
|
||||
users: Mapped[list[User]] = relationship(
|
||||
"User",
|
||||
secondary=User__UserGroup.__table__,
|
||||
)
|
||||
cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
secondary=UserGroup__ConnectorCredentialPair.__table__,
|
||||
viewonly=True,
|
||||
)
|
||||
cc_pair_relationships: Mapped[
|
||||
list[UserGroup__ConnectorCredentialPair]
|
||||
] = relationship(
|
||||
"UserGroup__ConnectorCredentialPair",
|
||||
viewonly=True,
|
||||
)
|
32
backend/ee/danswer/db/persona.py
Normal file
32
backend/ee/danswer/db/persona.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Persona__User
|
||||
from danswer.db.models import Persona__UserGroup
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
persona_id: int,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
for user_uuid in user_ids:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
|
||||
|
||||
if group_ids:
|
||||
for group_id in group_ids:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
db_session.commit()
|
@@ -8,8 +8,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.db.models import SamlAccount
|
||||
from danswer.db.models import User
|
||||
from ee.danswer.db.models import SamlAccount
|
||||
|
||||
|
||||
def upsert_saml_account(
|
||||
|
@@ -10,10 +10,10 @@ from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from ee.danswer.db.models import User__UserGroup
|
||||
from ee.danswer.db.models import UserGroup
|
||||
from ee.danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from ee.danswer.server.user_group.models import UserGroupCreate
|
||||
from ee.danswer.server.user_group.models import UserGroupUpdate
|
||||
|
||||
|
@@ -1,14 +1,14 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from danswer.server.documents.models import (
|
||||
ConnectorCredentialPairDescriptor,
|
||||
ConnectorSnapshot,
|
||||
CredentialSnapshot,
|
||||
)
|
||||
from danswer.server.manage.models import UserInfo
|
||||
|
||||
from ee.danswer.db.models import UserGroup as UserGroupModel
|
||||
from danswer.db.models import UserGroup as UserGroupModel
|
||||
from danswer.server.documents.models import ConnectorCredentialPairDescriptor
|
||||
from danswer.server.documents.models import ConnectorSnapshot
|
||||
from danswer.server.documents.models import CredentialSnapshot
|
||||
from danswer.server.features.document_set.models import DocumentSet
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.manage.models import UserInfo
|
||||
|
||||
|
||||
class UserGroup(BaseModel):
|
||||
@@ -16,14 +16,16 @@ class UserGroup(BaseModel):
|
||||
name: str
|
||||
users: list[UserInfo]
|
||||
cc_pairs: list[ConnectorCredentialPairDescriptor]
|
||||
document_sets: list[DocumentSet]
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, document_set_model: UserGroupModel) -> "UserGroup":
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
return cls(
|
||||
id=document_set_model.id,
|
||||
name=document_set_model.name,
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
users=[
|
||||
UserInfo(
|
||||
id=str(user.id),
|
||||
@@ -33,7 +35,7 @@ class UserGroup(BaseModel):
|
||||
is_verified=user.is_verified,
|
||||
role=user.role,
|
||||
)
|
||||
for user in document_set_model.users
|
||||
for user in user_group_model.users
|
||||
],
|
||||
cc_pairs=[
|
||||
ConnectorCredentialPairDescriptor(
|
||||
@@ -46,11 +48,18 @@ class UserGroup(BaseModel):
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
)
|
||||
for cc_pair_relationship in document_set_model.cc_pair_relationships
|
||||
for cc_pair_relationship in user_group_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
],
|
||||
is_up_to_date=document_set_model.is_up_to_date,
|
||||
is_up_for_deletion=document_set_model.is_up_for_deletion,
|
||||
document_sets=[
|
||||
DocumentSet.from_model(ds) for ds in user_group_model.document_sets
|
||||
],
|
||||
personas=[
|
||||
PersonaSnapshot.from_model(persona)
|
||||
for persona in user_group_model.personas
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user