Move is_public from Credential to ConnectorCredentialPair (#523)

This commit is contained in:
Chris Weaver 2023-10-05 20:55:41 -07:00 committed by GitHub
parent a85e73edbe
commit 9c89ae78ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 136 additions and 61 deletions

View File

@ -0,0 +1,49 @@
"""Move is_public to cc_pair
Revision ID: 3b25685ff73c
Revises: e0a68a81d434
Create Date: 2023-10-05 18:47:09.582849
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3b25685ff73c"
down_revision = "e0a68a81d434"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Move the visibility flag from credentials to connector-credential pairs.

    Adds ``connector_credential_pair.is_public`` and ``credential.is_admin`` as
    NOT NULL boolean columns, backfilling every existing row with ``true``,
    and then drops ``credential.public_doc``.
    """
    cc_pair_table = "connector_credential_pair"
    credential_table = "credential"

    # Each new column is created nullable, backfilled, and only then made
    # NOT NULL, so existing rows never violate the constraint.
    is_public_column = sa.Column("is_public", sa.Boolean(), nullable=True)
    op.add_column(cc_pair_table, is_public_column)
    # NOTE(review): every existing pair becomes public here; the per-credential
    # `public_doc` values are not carried over before that column is dropped
    # below -- confirm this is the intended post-migration state.
    op.execute(
        "UPDATE connector_credential_pair SET is_public = true WHERE is_public IS NULL"
    )
    op.alter_column(cc_pair_table, "is_public", nullable=False)

    is_admin_column = sa.Column("is_admin", sa.Boolean(), nullable=True)
    op.add_column(credential_table, is_admin_column)
    op.execute("UPDATE credential SET is_admin = true WHERE is_admin IS NULL")
    op.alter_column(credential_table, "is_admin", nullable=False)

    # The old flag is superseded by `connector_credential_pair.is_public`.
    op.drop_column(credential_table, "public_doc")
def downgrade() -> None:
    """Restore ``credential.public_doc`` and remove the columns added by upgrade.

    The original ``public_doc`` values were dropped during the upgrade and
    cannot be recovered, so every existing credential is reset to ``false``.
    """
    credential_table = "credential"

    # Re-create the column as nullable so it can be backfilled before the
    # NOT NULL constraint is applied.
    restored_column = sa.Column("public_doc", sa.Boolean(), nullable=True)
    op.add_column(credential_table, restored_column)

    # Defaulting everything to non-public is the safe choice.
    # NOTE: this is likely not the true prior state, but it's the best we can do.
    op.execute("UPDATE credential SET public_doc = false WHERE public_doc IS NULL")
    op.alter_column(credential_table, "public_doc", nullable=False)

    op.drop_column("connector_credential_pair", "is_public")
    op.drop_column(credential_table, "is_admin")

View File

@ -287,13 +287,9 @@ def _run_indexing(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
index_user_id = (
None if db_credential.public_doc else db_credential.user_id
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
user_id=index_user_id,
connector_id=db_connector.id,
credential_id=db_credential.id,
),

View File

@ -130,7 +130,7 @@ def build_service_account_creds(
return CredentialBase(
credential_json=credential_dict,
public_doc=True,
is_admin=True,
)

View File

@ -1,7 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any
from uuid import UUID
from danswer.configs.constants import DocumentSource
@ -41,6 +40,5 @@ class InputType(str, Enum):
@dataclass
class IndexAttemptMetadata:
user_id: UUID | None
connector_id: int
credential_id: int

View File

@ -1,5 +1,6 @@
from typing import Any
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_
@ -19,18 +20,30 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) -> Select:
    """Restrict a credential query to the rows ``user`` is allowed to see.

    - no user: the statement is returned unchanged (no filtering applied)
    - admin: credentials they own, credentials with no owner, and any
      credential flagged ``is_admin``
    - any other user: only credentials they own
    """
    if not user:
        return stmt

    owned_by_user = Credential.user_id == user.id
    if user.role != UserRole.ADMIN:
        return stmt.where(owned_by_user)

    visible_to_admin = or_(
        owned_by_user,
        Credential.user_id.is_(None),
        Credential.is_admin == True,  # noqa: E712
    )
    return stmt.where(visible_to_admin)
def fetch_credentials(
db_session: Session,
user: User | None = None,
public_only: bool | None = None,
) -> list[Credential]:
stmt = select(Credential)
if user:
stmt = stmt.where(
or_(Credential.user_id == user.id, Credential.user_id.is_(None))
)
if public_only is not None:
stmt = stmt.where(Credential.public_doc == public_only)
stmt = _attach_user_filters(stmt, user)
results = db_session.scalars(stmt)
return list(results.all())
@ -39,20 +52,7 @@ def fetch_credential_by_id(
credential_id: int, user: User | None, db_session: Session
) -> Credential | None:
stmt = select(Credential).where(Credential.id == credential_id)
if user:
# admins have access to all public credentials + credentials they own
if user.role == UserRole.ADMIN:
stmt = stmt.where(
or_(
Credential.user_id == user.id,
Credential.user_id.is_(None),
Credential.public_doc == True, # noqa: E712
)
)
else:
stmt = stmt.where(
or_(Credential.user_id == user.id, Credential.user_id.is_(None))
)
stmt = _attach_user_filters(stmt, user)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
@ -60,13 +60,13 @@ def fetch_credential_by_id(
def create_credential(
credential_data: CredentialBase,
user: User,
user: User | None,
db_session: Session,
) -> ObjectCreationIdResponse:
credential = Credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
public_doc=credential_data.public_doc,
is_admin=credential_data.is_admin,
)
db_session.add(credential)
db_session.commit()
@ -86,7 +86,6 @@ def update_credential(
credential.credential_json = credential_data.credential_json
credential.user_id = user.id if user is not None else None
credential.public_doc = credential_data.public_doc
db_session.commit()
return credential
@ -144,13 +143,15 @@ def create_initial_public_credential() -> None:
if first_credential is not None:
if (
first_credential.credential_json != {}
or first_credential.public_doc is False
or first_credential.user is not None
):
raise ValueError(error_msg)
return
credential = Credential(
id=public_cred_id, credential_json={}, user_id=None, public_doc=True
id=public_cred_id,
credential_json={},
user_id=None,
)
db_session.add(credential)
db_session.commit()

View File

@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DEFAULT_BOOST
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.feedback import delete_document_feedback_for_documents
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
@ -69,7 +70,7 @@ def get_acccess_info_for_documents(
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(Credential.public_doc).label("public_doc"),
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
# pretend that the specified cc pair doesn't exist
@ -83,10 +84,22 @@ def get_acccess_info_for_documents(
)
)
stmt = stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
).group_by(DocumentByConnectorCredentialPair.id)
stmt = (
stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
)
.join(
ConnectorCredentialPair,
and_(
DocumentByConnectorCredentialPair.connector_id
== ConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id
== ConnectorCredentialPair.credential_id,
),
)
.group_by(DocumentByConnectorCredentialPair.id)
)
return db_session.execute(stmt).all() # type: ignore

View File

@ -144,6 +144,14 @@ class ConnectorCredentialPair(Base):
credential_id: Mapped[int] = mapped_column(
ForeignKey("credential.id"), primary_key=True
)
# controls whether the documents indexed by this CC pair are visible to everyone
# or only to those who are given explicit access
# (e.g. by owning the credential or being part of a group that is given access)
is_public: Mapped[bool] = mapped_column(
Boolean,
default=True,
nullable=False,
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
@ -206,7 +214,8 @@ class Credential(Base):
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
public_doc: Mapped[bool] = mapped_column(Boolean, default=False)
# if `true`, then all Admins will have access to the credential
is_admin: Mapped[bool] = mapped_column(Boolean, default=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

View File

@ -3,6 +3,7 @@ from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.credentials import create_credential
@ -26,11 +27,11 @@ router = APIRouter(prefix="/manage")
@router.get("/admin/credential")
def list_credentials_admin(
_: User = Depends(current_admin_user),
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
"""Lists all public credentials"""
credentials = fetch_credentials(db_session=db_session, public_only=True)
credentials = fetch_credentials(db_session=db_session, user=user)
return [
CredentialSnapshot.from_credential_db_model(credential)
for credential in credentials
@ -65,6 +66,21 @@ def list_credentials(
]
@router.post("/credential")
def create_credential_from_model(
    credential_info: CredentialBase,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
    """Create a credential on behalf of the requesting user.

    A request from a user whose role is not ADMIN is rejected with HTTP 400.
    Requests with no user, or from an admin, are forwarded to
    `create_credential`.
    """
    # NOTE(review): the check below does not look at `credential_info.is_admin`,
    # although the error message refers to an "admin credential" -- confirm
    # whether non-admins should be allowed to create non-admin credentials.
    requester_is_non_admin = bool(user) and user.role != UserRole.ADMIN
    if requester_is_non_admin:
        raise HTTPException(
            status_code=400,
            detail="Non-admin cannot create admin credential",
        )

    return create_credential(credential_info, user, db_session)
@router.get("/credential/{credential_id}")
def get_credential_by_id(
credential_id: int,
@ -81,15 +97,6 @@ def get_credential_by_id(
return CredentialSnapshot.from_credential_db_model(credential)
@router.post("/credential")
def create_credential_from_model(
connector_info: CredentialBase,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
return create_credential(connector_info, user, db_session)
@router.patch("/credential/{credential_id}")
def update_credential_from_model(
credential_id: int,
@ -110,7 +117,7 @@ def update_credential_from_model(
id=updated_credential.id,
credential_json=updated_credential.credential_json,
user_id=updated_credential.user_id,
public_doc=updated_credential.public_doc,
is_admin=updated_credential.is_admin,
time_created=updated_credential.time_created,
time_updated=updated_credential.time_updated,
)

View File

@ -215,7 +215,7 @@ def delete_google_service_account_key(
@router.put("/admin/connector/google-drive/service-account-credential")
def upsert_service_account_credential(
service_account_credential_request: GoogleServiceAccountCredentialRequest,
user: User = Depends(current_admin_user),
user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
"""Special API which allows the creation of a credential for a service account.
@ -225,12 +225,12 @@ def upsert_service_account_credential(
credential_base = build_service_account_creds(
delegated_user_email=service_account_credential_request.google_drive_delegated_user
)
print(credential_base)
except ConfigNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
# first delete all existing service account credentials
delete_google_drive_service_account_credentials(user, db_session)
# `user=None` since this credential is not a personal credential
return create_credential(
credential_data=credential_base, user=user, db_session=db_session
)
@ -322,7 +322,7 @@ def get_connector_indexing_status(
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(connector),
credential=CredentialSnapshot.from_credential_db_model(credential),
public_doc=credential.public_doc,
public_doc=cc_pair.is_public,
owner=credential.user.email if credential.user else "",
last_status=cc_pair.last_attempt_status,
last_success=cc_pair.last_successful_index_time,

View File

@ -336,7 +336,7 @@ class RunConnectorRequest(BaseModel):
class CredentialBase(BaseModel):
credential_json: dict[str, Any]
public_doc: bool
is_admin: bool
class CredentialSnapshot(CredentialBase):
@ -353,7 +353,7 @@ class CredentialSnapshot(CredentialBase):
if MASK_CREDENTIAL_PREFIX
else credential.credential_json,
user_id=credential.user_id,
public_doc=credential.public_doc,
is_admin=credential.is_admin,
time_created=credential.time_created,
time_updated=credential.time_updated,
)

View File

@ -321,7 +321,8 @@ const Main = () => {
| Credential<GoogleDriveCredentialJson>
| undefined = credentialsData.find(
(credential) =>
credential.credential_json?.google_drive_tokens && credential.public_doc
credential.credential_json?.google_drive_tokens &&
credential.user_id === null
);
const googleDriveServiceAccountCredential:
| Credential<GoogleDriveServiceAccountCredentialJson>

View File

@ -152,7 +152,7 @@ function Main() {
<ConnectorTitle
ccPairName={connectorIndexingStatus.name}
connector={connectorIndexingStatus.connector}
isPublic={connectorIndexingStatus.credential.public_doc}
isPublic={connectorIndexingStatus.public_doc}
owner={connectorIndexingStatus.owner}
/>
),

View File

@ -20,8 +20,9 @@ export const GoogleDriveCard = ({
const existingCredential: Credential<GoogleDriveCredentialJson> | undefined =
userCredentials?.find(
(credential) =>
// user_id is set => credential is not a public credential
credential.credential_json?.google_drive_tokens !== undefined &&
!credential.public_doc
credential.user_id !== null
);
const credentialIsLinked =

View File

@ -57,7 +57,7 @@ export function CredentialForm<T extends Yup.AnyObject>({
formikHelpers.setSubmitting(true);
submitCredential<T>({
credential_json: values,
public_doc: true,
is_admin: true,
}).then(({ message, isSuccess }) => {
setPopup({ message, type: isSuccess ? "success" : "error" });
formikHelpers.setSubmitting(false);

View File

@ -139,7 +139,7 @@ export interface ConnectorIndexingStatus<
// CREDENTIALS
export interface CredentialBase<T> {
credential_json: T;
public_doc: boolean;
is_admin: boolean;
}
export interface Credential<T> extends CredentialBase<T> {