Private personas doc sets (#52)

Private Personas and Document Sets

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
Chris Weaver 2024-03-21 13:51:38 -07:00
parent 680482bd06
commit 17cc262f5d
19 changed files with 257 additions and 156 deletions

View File

@ -311,14 +311,14 @@ def kickoff_indexing_jobs(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False
pure=False,
)
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
global_version.get_is_ee_version(),
pure=False
pure=False,
)
if run:

View File

@ -1129,6 +1129,7 @@ class PGFileStore(Base):
lobj_oid: Mapped[int] = mapped_column(Integer, nullable=False)
"""
************************************************************************
Enterprise Edition Models

View File

@ -0,0 +1,123 @@
from uuid import UUID
from sqlalchemy.orm import Session
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import DocumentSet
from danswer.db.models import DocumentSet__ConnectorCredentialPair
from danswer.db.models import DocumentSet__User
from danswer.db.models import DocumentSet__UserGroup
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
def make_doc_set_private(
    document_set_id: int,
    user_ids: list[UUID] | None,
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    """Replace the user and user-group access rows of a document set.

    All existing ``DocumentSet__User`` and ``DocumentSet__UserGroup`` rows for
    ``document_set_id`` are deleted, then one new row is staged for every
    entry in ``user_ids`` and ``group_ids``. ``None`` or an empty list stages
    no rows of that kind.

    This function does not commit; the changes stay pending on
    ``db_session`` (presumably the caller commits -- verify against callers).
    """
    # Clearing the old rows is exactly what the sibling helper does; call it
    # instead of repeating both delete statements here.
    delete_document_set_privacy__no_commit(document_set_id, db_session)

    if user_ids:
        db_session.add_all(
            DocumentSet__User(document_set_id=document_set_id, user_id=user_uuid)
            for user_uuid in user_ids
        )

    if group_ids:
        db_session.add_all(
            DocumentSet__UserGroup(
                document_set_id=document_set_id, user_group_id=group_id
            )
            for group_id in group_ids
        )
def delete_document_set_privacy__no_commit(
    document_set_id: int, db_session: Session
) -> None:
    """Delete every user and user-group access row of a document set.

    As the name says, no commit is issued; the deletes stay pending on
    ``db_session``.
    """
    # Same order as before: per-user rows first, then per-group rows.
    for association_model in (DocumentSet__User, DocumentSet__UserGroup):
        stale_rows = db_session.query(association_model).filter(
            association_model.document_set_id == document_set_id
        )
        stale_rows.delete(synchronize_session="fetch")
def fetch_document_sets(
    user_id: UUID | None,
    db_session: Session,
    include_outdated: bool = True,  # Parameter only for versioned implementation, unused
) -> list[tuple[DocumentSet, list[ConnectorCredentialPair]]]:
    """Return the document sets a user can see, each with its cc-pairs.

    A document set is included when it is public, when it is shared directly
    with ``user_id`` through ``DocumentSet__User``, or when it is shared with
    a user group that ``user_id`` belongs to. Each document set appears once,
    in first-seen order (public, then user-shared, then group-shared).

    Args:
        user_id: The user to resolve access for. Must not be ``None``.
        db_session: Session used for all queries.
        include_outdated: Unused here; kept for signature compatibility.

    Returns:
        A list of ``(document_set, connector_credential_pairs)`` tuples.
    """
    assert user_id is not None

    # Public document sets
    public_document_sets = (
        db_session.query(DocumentSet)
        .filter(DocumentSet.is_public == True)  # noqa
        .all()
    )

    # Document sets via shared user relationships
    shared_document_sets = (
        db_session.query(DocumentSet)
        .join(DocumentSet__User, DocumentSet.id == DocumentSet__User.document_set_id)
        .filter(DocumentSet__User.user_id == user_id)
        .all()
    )

    # Document sets via groups, resolved in a single query by joining through
    # the user's group memberships instead of one query per group.
    group_document_sets = (
        db_session.query(DocumentSet)
        .join(
            DocumentSet__UserGroup,
            DocumentSet.id == DocumentSet__UserGroup.document_set_id,
        )
        .join(UserGroup, UserGroup.id == DocumentSet__UserGroup.user_group_id)
        .join(User__UserGroup, UserGroup.id == User__UserGroup.user_group_id)
        .filter(User__UserGroup.user_id == user_id)
        .all()
    )

    # Deduplicate on primary key while keeping first-seen order, so the
    # result order no longer depends on set iteration order.
    unique_document_sets: dict[int, DocumentSet] = {}
    for document_set in (
        *public_document_sets,
        *shared_document_sets,
        *group_document_sets,
    ):
        unique_document_sets.setdefault(document_set.id, document_set)

    document_set_with_cc_pairs: list[
        tuple[DocumentSet, list[ConnectorCredentialPair]]
    ] = []
    for document_set in unique_document_sets.values():
        # Fetch the associated ConnectorCredentialPairs
        cc_pairs = (
            db_session.query(ConnectorCredentialPair)
            .join(
                DocumentSet__ConnectorCredentialPair,
                ConnectorCredentialPair.id
                == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
            )
            .filter(
                DocumentSet__ConnectorCredentialPair.document_set_id == document_set.id,
            )
            .all()
        )
        document_set_with_cc_pairs.append((document_set, cc_pairs))  # type: ignore

    return document_set_with_cc_pairs

View File

@ -1,96 +0,0 @@
import datetime
from uuid import UUID
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import String
from sqlalchemy import Text
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from danswer.db.models import Base
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import User
# ORM model for the `saml` table: an encrypted cookie with an expiry, tied to
# a single user.
class SamlAccount(Base):
__tablename__ = "saml"
id: Mapped[int] = mapped_column(primary_key=True)
# unique=True: at most one row per user.
# NOTE(review): annotated as int, while User__UserGroup in this file maps its
# foreign key to "user.id" as UUID -- confirm which Python type is intended.
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
# unique=True: the same cookie value cannot be stored twice.
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
# set to now() by the database on insert and refreshed on every update
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
user: Mapped[User] = relationship("User")
"""Tables related to RBAC"""
# Association table linking users to user groups; the composite primary key
# is (user_group_id, user_id).
class User__UserGroup(Base):
__tablename__ = "user__user_group"
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# Association table linking user groups to connector-credential pairs.
# `is_current` is part of the primary key, so the same (group, cc_pair) pair
# can exist once as current state and once as prior state.
class UserGroup__ConnectorCredentialPair(Base):
__tablename__ = "user_group__connector_credential_pair"
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
cc_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"), primary_key=True
)
# if `True`, then is part of the current state of the UserGroup
# if `False`, then is a part of the prior state of the UserGroup
# rows with `is_current=False` should be deleted when the UserGroup
# is updated and should not exist for a given UserGroup if
# `UserGroup.is_up_to_date == True`
is_current: Mapped[bool] = mapped_column(
Boolean,
default=True,
primary_key=True,
)
# the ConnectorCredentialPair row this association points at
cc_pair: Mapped[ConnectorCredentialPair] = relationship(
"ConnectorCredentialPair",
)
# ORM model for the `user_group` table: a named group of users associated
# with a set of connector-credential pairs.
class UserGroup(Base):
__tablename__ = "user_group"
id: Mapped[int] = mapped_column(primary_key=True)
# group names are unique across the table
name: Mapped[str] = mapped_column(String, unique=True)
# whether or not changes to the UserGroup have been propagated to Vespa
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# tell the sync job to clean up the group
is_up_for_deletion: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
# members of the group, resolved through the user__user_group table
users: Mapped[list[User]] = relationship(
"User",
secondary=User__UserGroup.__table__,
)
# read-only (viewonly) list of cc-pairs, resolved through the
# user_group__connector_credential_pair table
cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
"ConnectorCredentialPair",
secondary=UserGroup__ConnectorCredentialPair.__table__,
viewonly=True,
)
# read-only (viewonly) access to the association rows themselves, which
# carry the `is_current` flag
cc_pair_relationships: Mapped[
list[UserGroup__ConnectorCredentialPair]
] = relationship(
"UserGroup__ConnectorCredentialPair",
viewonly=True,
)

View File

@ -0,0 +1,32 @@
from uuid import UUID
from sqlalchemy.orm import Session
from danswer.db.models import Persona__User
from danswer.db.models import Persona__UserGroup
def make_persona_private(
    persona_id: int,
    user_ids: list[UUID] | None,
    group_ids: list[int] | None,
    db_session: Session,
) -> None:
    """Replace the user and user-group access rows of a persona, then commit.

    Existing ``Persona__User`` and ``Persona__UserGroup`` rows for
    ``persona_id`` are deleted, one new row is added per entry in
    ``user_ids`` and ``group_ids`` (``None`` or an empty list adds nothing),
    and the session is committed.
    """
    # Drop the previous rows: per-user first, then per-group.
    for association_model in (Persona__User, Persona__UserGroup):
        stale_rows = db_session.query(association_model).filter(
            association_model.persona_id == persona_id
        )
        stale_rows.delete(synchronize_session="fetch")

    db_session.add_all(
        Persona__User(persona_id=persona_id, user_id=user_uuid)
        for user_uuid in user_ids or ()
    )
    db_session.add_all(
        Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
        for group_id in group_ids or ()
    )

    db_session.commit()

View File

@ -8,8 +8,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.db.models import SamlAccount
from danswer.db.models import User
from ee.danswer.db.models import SamlAccount
def upsert_saml_account(

View File

@ -10,10 +10,10 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from ee.danswer.db.models import User__UserGroup
from ee.danswer.db.models import UserGroup
from ee.danswer.db.models import UserGroup__ConnectorCredentialPair
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate

View File

@ -1,14 +1,14 @@
from uuid import UUID
from pydantic import BaseModel
from danswer.server.documents.models import (
ConnectorCredentialPairDescriptor,
ConnectorSnapshot,
CredentialSnapshot,
)
from danswer.server.manage.models import UserInfo
from ee.danswer.db.models import UserGroup as UserGroupModel
from danswer.db.models import UserGroup as UserGroupModel
from danswer.server.documents.models import ConnectorCredentialPairDescriptor
from danswer.server.documents.models import ConnectorSnapshot
from danswer.server.documents.models import CredentialSnapshot
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.manage.models import UserInfo
class UserGroup(BaseModel):
@ -16,14 +16,16 @@ class UserGroup(BaseModel):
name: str
users: list[UserInfo]
cc_pairs: list[ConnectorCredentialPairDescriptor]
document_sets: list[DocumentSet]
personas: list[PersonaSnapshot]
is_up_to_date: bool
is_up_for_deletion: bool
@classmethod
def from_model(cls, document_set_model: UserGroupModel) -> "UserGroup":
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
return cls(
id=document_set_model.id,
name=document_set_model.name,
id=user_group_model.id,
name=user_group_model.name,
users=[
UserInfo(
id=str(user.id),
@ -33,7 +35,7 @@ class UserGroup(BaseModel):
is_verified=user.is_verified,
role=user.role,
)
for user in document_set_model.users
for user in user_group_model.users
],
cc_pairs=[
ConnectorCredentialPairDescriptor(
@ -46,11 +48,18 @@ class UserGroup(BaseModel):
cc_pair_relationship.cc_pair.credential
),
)
for cc_pair_relationship in document_set_model.cc_pair_relationships
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current
],
is_up_to_date=document_set_model.is_up_to_date,
is_up_for_deletion=document_set_model.is_up_for_deletion,
document_sets=[
DocumentSet.from_model(ds) for ds in user_group_model.document_sets
],
personas=[
PersonaSnapshot.from_model(persona)
for persona in user_group_model.personas
],
is_up_to_date=user_group_model.is_up_to_date,
is_up_for_deletion=user_group_model.is_up_for_deletion,
)

View File

@ -21,7 +21,7 @@ def run_jobs(exclude_indexing: bool) -> None:
cmd_worker = [
"celery",
"-A",
"danswer.background.celery",
"ee.danswer.background.celery",
"worker",
"--pool=threads",
"--autoscale=3,10",
@ -29,7 +29,13 @@ def run_jobs(exclude_indexing: bool) -> None:
"--concurrency=1",
]
cmd_beat = ["celery", "-A", "danswer.background.celery", "beat", "--loglevel=INFO"]
cmd_beat = [
"celery",
"-A",
"ee.danswer.background.celery",
"beat",
"--loglevel=INFO",
]
worker_process = subprocess.Popen(
cmd_worker, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True

View File

@ -1,3 +1,4 @@
import { EE_ENABLED } from "@/lib/constants";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { DocumentSet } from "@/lib/types";
import useSWR, { mutate } from "swr";

View File

@ -1,11 +1,10 @@
import { Form, Formik } from "formik";
import * as Yup from "yup";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { ConnectorIndexingStatus, User } from "@/lib/types";
import { ConnectorIndexingStatus, User, UserGroup } from "@/lib/types";
import { TextFormField } from "@/components/admin/connectors/Field";
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
import { createUserGroup } from "./lib";
import { UserGroup } from "./types";
import { UserEditor } from "./UserEditor";
import { ConnectorEditor } from "./ConnectorEditor";
import { Modal } from "@/components/Modal";

View File

@ -8,7 +8,6 @@ import {
TableBody,
TableCell,
} from "@tremor/react";
import { UserGroup } from "./types";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { LoadingAnimation } from "@/components/Loading";
import { BasicTable } from "@/components/admin/connectors/BasicTable";
@ -17,7 +16,7 @@ import { TrashIcon } from "@/components/icons/icons";
import { deleteUserGroup } from "./lib";
import { useRouter } from "next/navigation";
import { FiEdit, FiUser } from "react-icons/fi";
import { User } from "@/lib/types";
import { User, UserGroup } from "@/lib/types";
import Link from "next/link";
import { DeleteButton } from "@/components/DeleteButton";

View File

@ -5,9 +5,8 @@ import { UsersIcon } from "@/components/icons/icons";
import { useState } from "react";
import { FiPlus, FiX } from "react-icons/fi";
import { updateUserGroup } from "./lib";
import { UserGroup } from "../types";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { Connector, ConnectorIndexingStatus } from "@/lib/types";
import { Connector, ConnectorIndexingStatus, UserGroup } from "@/lib/types";
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
interface AddConnectorFormProps {

View File

@ -1,8 +1,7 @@
import { Modal } from "@/components/Modal";
import { updateUserGroup } from "./lib";
import { UserGroup } from "../types";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { User } from "@/lib/types";
import { User, UserGroup } from "@/lib/types";
import { UserEditor } from "../UserEditor";
import { useState } from "react";

View File

@ -2,12 +2,11 @@
import { usePopup } from "@/components/admin/connectors/Popup";
import { useState } from "react";
import { UserGroup } from "../types";
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
import { AddMemberForm } from "./AddMemberForm";
import { updateUserGroup } from "./lib";
import { LoadingAnimation } from "@/components/Loading";
import { ConnectorIndexingStatus, User } from "@/lib/types";
import { ConnectorIndexingStatus, User, UserGroup } from "@/lib/types";
import { AddConnectorForm } from "./AddConnectorForm";
import {
Table,
@ -21,6 +20,8 @@ import {
Text,
} from "@tremor/react";
import { DeleteButton } from "@/components/DeleteButton";
import { Bubble } from "@/components/Bubble";
import { BookmarkIcon, RobotIcon } from "@/components/icons/icons";
interface GroupDisplayProps {
users: User[];
@ -250,6 +251,56 @@ export const GroupDisplay = ({
setPopup={setPopup}
/>
)}
<Divider />
<h2 className="text-xl font-bold mt-8 mb-2">Document Sets</h2>
<div>
{userGroup.document_sets.length > 0 ? (
<div className="flex flex-wrap gap-2">
{userGroup.document_sets.map((documentSet) => {
return (
<Bubble isSelected key={documentSet.id}>
<div className="flex">
<BookmarkIcon />
<Text className="ml-1">{documentSet.name}</Text>
</div>
</Bubble>
);
})}
</div>
) : (
<>
<Text>No document sets in this group...</Text>
</>
)}
</div>
<Divider />
<h2 className="text-xl font-bold mt-8 mb-2">Personas</h2>
<div>
{/* Gate on the personas list itself: checking document_sets here hid the
    personas of groups without document sets and skipped the empty-state
    message for groups that have document sets but no personas. */}
{userGroup.personas.length > 0 ? (
  <div className="flex flex-wrap gap-2">
    {userGroup.personas.map((persona) => {
      return (
        <Bubble isSelected key={persona.id}>
          <div className="flex">
            <RobotIcon />
            <Text className="ml-1">{persona.name}</Text>
          </div>
        </Bubble>
      );
    })}
  </div>
) : (
  <>
    <Text>No Personas in this group...</Text>
  </>
)}
</div>
</div>
);
};

View File

@ -1,4 +1,4 @@
import { useUserGroups } from "../hooks";
import { useUserGroups } from "@/lib/hooks";
export const useSpecificUserGroup = (groupId: string) => {
const { data, isLoading, error, refreshUserGroups } = useUserGroups();

View File

@ -1,14 +0,0 @@
import useSWR, { mutate } from "swr";
import { UserGroup } from "./types";
import { errorHandlingFetcher } from "@/lib/fetcher";
const USER_GROUP_URL = "/api/manage/admin/user-group";
/**
 * SWR hook that fetches the user-group list from USER_GROUP_URL.
 *
 * Returns the SWR response fields plus `refreshUserGroups`, which calls
 * `mutate` on the same key.
 */
export const useUserGroups = () => {
  const userGroupsResponse = useSWR<UserGroup[]>(
    USER_GROUP_URL,
    errorHandlingFetcher
  );
  const refreshUserGroups = () => mutate(USER_GROUP_URL);

  return { ...userGroupsResponse, refreshUserGroups };
};

View File

@ -6,8 +6,11 @@ import { UserGroupCreationForm } from "./UserGroupCreationForm";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useState } from "react";
import { ThreeDotsLoader } from "@/components/Loading";
import { useConnectorCredentialIndexingStatus, useUsers } from "@/lib/hooks";
import { useUserGroups } from "./hooks";
import {
useConnectorCredentialIndexingStatus,
useUserGroups,
useUsers,
} from "@/lib/hooks";
import { AdminPageTitle } from "@/components/admin/Title";
import { Button, Divider } from "@tremor/react";

View File

@ -1,19 +1,8 @@
import { CCPairDescriptor, User } from "@/lib/types";
/**
 * User ids and connector-credential-pair ids for a user group.
 * NOTE(review): presumably the request body sent by `updateUserGroup` -- confirm against ./lib.
 */
export interface UserGroupUpdate {
user_ids: string[];
cc_pair_ids: number[];
}
/**
 * A user group: its id and name, its users and cc-pairs, and two status flags.
 * NOTE(review): the flag names match the `is_up_to_date` / `is_up_for_deletion`
 * fields on the backend UserGroup model -- confirm they carry the same meaning.
 */
export interface UserGroup {
id: number;
name: string;
users: User[];
cc_pairs: CCPairDescriptor<any, any>[];
is_up_to_date: boolean;
is_up_for_deletion: boolean;
}
export interface UserGroupCreation {
name: string;
user_ids: string[];