mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-31 16:08:04 +02:00
124 lines
3.9 KiB
Python
124 lines
3.9 KiB
Python
from uuid import UUID
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.db.models import ConnectorCredentialPair
|
|
from onyx.db.models import DocumentSet
|
|
from onyx.db.models import DocumentSet__ConnectorCredentialPair
|
|
from onyx.db.models import DocumentSet__User
|
|
from onyx.db.models import DocumentSet__UserGroup
|
|
from onyx.db.models import User__UserGroup
|
|
from onyx.db.models import UserGroup
|
|
|
|
|
|
def make_doc_set_private(
|
|
document_set_id: int,
|
|
user_ids: list[UUID] | None,
|
|
group_ids: list[int] | None,
|
|
db_session: Session,
|
|
) -> None:
|
|
db_session.query(DocumentSet__User).filter(
|
|
DocumentSet__User.document_set_id == document_set_id
|
|
).delete(synchronize_session="fetch")
|
|
db_session.query(DocumentSet__UserGroup).filter(
|
|
DocumentSet__UserGroup.document_set_id == document_set_id
|
|
).delete(synchronize_session="fetch")
|
|
|
|
if user_ids:
|
|
for user_uuid in user_ids:
|
|
db_session.add(
|
|
DocumentSet__User(document_set_id=document_set_id, user_id=user_uuid)
|
|
)
|
|
|
|
if group_ids:
|
|
for group_id in group_ids:
|
|
db_session.add(
|
|
DocumentSet__UserGroup(
|
|
document_set_id=document_set_id, user_group_id=group_id
|
|
)
|
|
)
|
|
|
|
|
|
def delete_document_set_privacy__no_commit(
|
|
document_set_id: int, db_session: Session
|
|
) -> None:
|
|
db_session.query(DocumentSet__User).filter(
|
|
DocumentSet__User.document_set_id == document_set_id
|
|
).delete(synchronize_session="fetch")
|
|
|
|
db_session.query(DocumentSet__UserGroup).filter(
|
|
DocumentSet__UserGroup.document_set_id == document_set_id
|
|
).delete(synchronize_session="fetch")
|
|
|
|
|
|
def fetch_document_sets(
|
|
user_id: UUID | None,
|
|
db_session: Session,
|
|
include_outdated: bool = True, # Parameter only for versioned implementation, unused
|
|
) -> list[tuple[DocumentSet, list[ConnectorCredentialPair]]]:
|
|
assert user_id is not None
|
|
|
|
# Public document sets
|
|
public_document_sets = (
|
|
db_session.query(DocumentSet)
|
|
.filter(DocumentSet.is_public == True) # noqa
|
|
.all()
|
|
)
|
|
|
|
# Document sets via shared user relationships
|
|
shared_document_sets = (
|
|
db_session.query(DocumentSet)
|
|
.join(DocumentSet__User, DocumentSet.id == DocumentSet__User.document_set_id)
|
|
.filter(DocumentSet__User.user_id == user_id)
|
|
.all()
|
|
)
|
|
|
|
# Document sets via groups
|
|
# First, find the user groups the user belongs to
|
|
user_groups = (
|
|
db_session.query(UserGroup)
|
|
.join(User__UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
|
.filter(User__UserGroup.user_id == user_id)
|
|
.all()
|
|
)
|
|
|
|
group_document_sets = []
|
|
for group in user_groups:
|
|
group_document_sets.extend(
|
|
db_session.query(DocumentSet)
|
|
.join(
|
|
DocumentSet__UserGroup,
|
|
DocumentSet.id == DocumentSet__UserGroup.document_set_id,
|
|
)
|
|
.filter(DocumentSet__UserGroup.user_group_id == group.id)
|
|
.all()
|
|
)
|
|
|
|
# Combine and deduplicate document sets from all sources
|
|
all_document_sets = list(
|
|
set(public_document_sets + shared_document_sets + group_document_sets)
|
|
)
|
|
|
|
document_set_with_cc_pairs: list[
|
|
tuple[DocumentSet, list[ConnectorCredentialPair]]
|
|
] = []
|
|
|
|
for document_set in all_document_sets:
|
|
# Fetch the associated ConnectorCredentialPairs
|
|
cc_pairs = (
|
|
db_session.query(ConnectorCredentialPair)
|
|
.join(
|
|
DocumentSet__ConnectorCredentialPair,
|
|
ConnectorCredentialPair.id
|
|
== DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
|
|
)
|
|
.filter(
|
|
DocumentSet__ConnectorCredentialPair.document_set_id == document_set.id,
|
|
)
|
|
.all()
|
|
)
|
|
|
|
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
|
|
|
|
return document_set_with_cc_pairs
|