mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 19:38:19 +02:00
Move is_public from Credential to ConnectorCredentialPair (#523)
This commit is contained in:
parent
a85e73edbe
commit
9c89ae78ba
@ -0,0 +1,49 @@
|
||||
"""Move is_public to cc_pair
|
||||
|
||||
Revision ID: 3b25685ff73c
|
||||
Revises: e0a68a81d434
|
||||
Create Date: 2023-10-05 18:47:09.582849
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b25685ff73c"
|
||||
down_revision = "e0a68a81d434"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
# fill in is_public for existing rows
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_public = true WHERE is_public IS NULL"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "is_public", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column("is_admin", sa.Boolean(), nullable=True),
|
||||
)
|
||||
op.execute("UPDATE credential SET is_admin = true WHERE is_admin IS NULL")
|
||||
op.alter_column("credential", "is_admin", nullable=False)
|
||||
|
||||
op.drop_column("credential", "public_doc")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"credential",
|
||||
sa.Column("public_doc", sa.Boolean(), nullable=True),
|
||||
)
|
||||
# setting public_doc to false for all existing rows to be safe
|
||||
# NOTE: this is likely not the correct state of the world but it's the best we can do
|
||||
op.execute("UPDATE credential SET public_doc = false WHERE public_doc IS NULL")
|
||||
op.alter_column("credential", "public_doc", nullable=False)
|
||||
op.drop_column("connector_credential_pair", "is_public")
|
||||
op.drop_column("credential", "is_admin")
|
@ -287,13 +287,9 @@ def _run_indexing(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
|
||||
index_user_id = (
|
||||
None if db_credential.public_doc else db_credential.user_id
|
||||
)
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
user_id=index_user_id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
),
|
||||
|
@ -130,7 +130,7 @@ def build_service_account_creds(
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
public_doc=True,
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
@ -41,6 +40,5 @@ class InputType(str, Enum):
|
||||
|
||||
@dataclass
|
||||
class IndexAttemptMetadata:
|
||||
user_id: UUID | None
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import or_
|
||||
@ -19,18 +20,30 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) -> Select:
|
||||
"""Attaches filters to the statement to ensure that the user can only
|
||||
access the appropriate credentials"""
|
||||
if user:
|
||||
if user.role == UserRole.ADMIN:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Credential.user_id == user.id,
|
||||
Credential.user_id.is_(None),
|
||||
Credential.is_admin == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = stmt.where(Credential.user_id == user.id)
|
||||
|
||||
return stmt
|
||||
|
||||
|
||||
def fetch_credentials(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
public_only: bool | None = None,
|
||||
) -> list[Credential]:
|
||||
stmt = select(Credential)
|
||||
if user:
|
||||
stmt = stmt.where(
|
||||
or_(Credential.user_id == user.id, Credential.user_id.is_(None))
|
||||
)
|
||||
if public_only is not None:
|
||||
stmt = stmt.where(Credential.public_doc == public_only)
|
||||
stmt = _attach_user_filters(stmt, user)
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
|
||||
@ -39,20 +52,7 @@ def fetch_credential_by_id(
|
||||
credential_id: int, user: User | None, db_session: Session
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).where(Credential.id == credential_id)
|
||||
if user:
|
||||
# admins have access to all public credentials + credentials they own
|
||||
if user.role == UserRole.ADMIN:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
Credential.user_id == user.id,
|
||||
Credential.user_id.is_(None),
|
||||
Credential.public_doc == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
else:
|
||||
stmt = stmt.where(
|
||||
or_(Credential.user_id == user.id, Credential.user_id.is_(None))
|
||||
)
|
||||
stmt = _attach_user_filters(stmt, user)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
@ -60,13 +60,13 @@ def fetch_credential_by_id(
|
||||
|
||||
def create_credential(
|
||||
credential_data: CredentialBase,
|
||||
user: User,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> ObjectCreationIdResponse:
|
||||
credential = Credential(
|
||||
credential_json=credential_data.credential_json,
|
||||
user_id=user.id if user else None,
|
||||
public_doc=credential_data.public_doc,
|
||||
is_admin=credential_data.is_admin,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
@ -86,7 +86,6 @@ def update_credential(
|
||||
|
||||
credential.credential_json = credential_data.credential_json
|
||||
credential.user_id = user.id if user is not None else None
|
||||
credential.public_doc = credential_data.public_doc
|
||||
|
||||
db_session.commit()
|
||||
return credential
|
||||
@ -144,13 +143,15 @@ def create_initial_public_credential() -> None:
|
||||
if first_credential is not None:
|
||||
if (
|
||||
first_credential.credential_json != {}
|
||||
or first_credential.public_doc is False
|
||||
or first_credential.user is not None
|
||||
):
|
||||
raise ValueError(error_msg)
|
||||
return
|
||||
|
||||
credential = Credential(
|
||||
id=public_cred_id, credential_json={}, user_id=None, public_doc=True
|
||||
id=public_cred_id,
|
||||
credential_json={},
|
||||
user_id=None,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
|
@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.datastores.interfaces import DocumentMetadata
|
||||
from danswer.db.feedback import delete_document_feedback_for_documents
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
@ -69,7 +70,7 @@ def get_acccess_info_for_documents(
|
||||
stmt = select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(Credential.user_id).label("user_ids"),
|
||||
func.bool_or(Credential.public_doc).label("public_doc"),
|
||||
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
|
||||
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
|
||||
# pretend that the specified cc pair doesn't exist
|
||||
@ -83,10 +84,22 @@ def get_acccess_info_for_documents(
|
||||
)
|
||||
)
|
||||
|
||||
stmt = stmt.join(
|
||||
Credential,
|
||||
DocumentByConnectorCredentialPair.credential_id == Credential.id,
|
||||
).group_by(DocumentByConnectorCredentialPair.id)
|
||||
stmt = (
|
||||
stmt.join(
|
||||
Credential,
|
||||
DocumentByConnectorCredentialPair.credential_id == Credential.id,
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.group_by(DocumentByConnectorCredentialPair.id)
|
||||
)
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
|
@ -144,6 +144,14 @@ class ConnectorCredentialPair(Base):
|
||||
credential_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("credential.id"), primary_key=True
|
||||
)
|
||||
# controls whether the documents indexed by this CC pair are visible to all
|
||||
# or if they are only visible to those with that are given explicit access
|
||||
# (e.g. via owning the credential or being a part of a group that is given access)
|
||||
is_public: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
)
|
||||
# Time finished, not used for calculating backend jobs which uses time started (created)
|
||||
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
@ -206,7 +214,8 @@ class Credential(Base):
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
public_doc: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# if `true`, then all Admins will have access to the credential
|
||||
is_admin: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
@ -3,6 +3,7 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.credentials import create_credential
|
||||
@ -26,11 +27,11 @@ router = APIRouter(prefix="/manage")
|
||||
|
||||
@router.get("/admin/credential")
|
||||
def list_credentials_admin(
|
||||
_: User = Depends(current_admin_user),
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[CredentialSnapshot]:
|
||||
"""Lists all public credentials"""
|
||||
credentials = fetch_credentials(db_session=db_session, public_only=True)
|
||||
credentials = fetch_credentials(db_session=db_session, user=user)
|
||||
return [
|
||||
CredentialSnapshot.from_credential_db_model(credential)
|
||||
for credential in credentials
|
||||
@ -65,6 +66,21 @@ def list_credentials(
|
||||
]
|
||||
|
||||
|
||||
@router.post("/credential")
|
||||
def create_credential_from_model(
|
||||
credential_info: CredentialBase,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ObjectCreationIdResponse:
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Non-admin cannot create admin credential",
|
||||
)
|
||||
|
||||
return create_credential(credential_info, user, db_session)
|
||||
|
||||
|
||||
@router.get("/credential/{credential_id}")
|
||||
def get_credential_by_id(
|
||||
credential_id: int,
|
||||
@ -81,15 +97,6 @@ def get_credential_by_id(
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
|
||||
|
||||
@router.post("/credential")
|
||||
def create_credential_from_model(
|
||||
connector_info: CredentialBase,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ObjectCreationIdResponse:
|
||||
return create_credential(connector_info, user, db_session)
|
||||
|
||||
|
||||
@router.patch("/credential/{credential_id}")
|
||||
def update_credential_from_model(
|
||||
credential_id: int,
|
||||
@ -110,7 +117,7 @@ def update_credential_from_model(
|
||||
id=updated_credential.id,
|
||||
credential_json=updated_credential.credential_json,
|
||||
user_id=updated_credential.user_id,
|
||||
public_doc=updated_credential.public_doc,
|
||||
is_admin=updated_credential.is_admin,
|
||||
time_created=updated_credential.time_created,
|
||||
time_updated=updated_credential.time_updated,
|
||||
)
|
||||
|
@ -215,7 +215,7 @@ def delete_google_service_account_key(
|
||||
@router.put("/admin/connector/google-drive/service-account-credential")
|
||||
def upsert_service_account_credential(
|
||||
service_account_credential_request: GoogleServiceAccountCredentialRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ObjectCreationIdResponse:
|
||||
"""Special API which allows the creation of a credential for a service account.
|
||||
@ -225,12 +225,12 @@ def upsert_service_account_credential(
|
||||
credential_base = build_service_account_creds(
|
||||
delegated_user_email=service_account_credential_request.google_drive_delegated_user
|
||||
)
|
||||
print(credential_base)
|
||||
except ConfigNotFoundError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# first delete all existing service account credentials
|
||||
delete_google_drive_service_account_credentials(user, db_session)
|
||||
# `user=None` since this credential is not a personal credential
|
||||
return create_credential(
|
||||
credential_data=credential_base, user=user, db_session=db_session
|
||||
)
|
||||
@ -322,7 +322,7 @@ def get_connector_indexing_status(
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(connector),
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
public_doc=credential.public_doc,
|
||||
public_doc=cc_pair.is_public,
|
||||
owner=credential.user.email if credential.user else "",
|
||||
last_status=cc_pair.last_attempt_status,
|
||||
last_success=cc_pair.last_successful_index_time,
|
||||
|
@ -336,7 +336,7 @@ class RunConnectorRequest(BaseModel):
|
||||
|
||||
class CredentialBase(BaseModel):
|
||||
credential_json: dict[str, Any]
|
||||
public_doc: bool
|
||||
is_admin: bool
|
||||
|
||||
|
||||
class CredentialSnapshot(CredentialBase):
|
||||
@ -353,7 +353,7 @@ class CredentialSnapshot(CredentialBase):
|
||||
if MASK_CREDENTIAL_PREFIX
|
||||
else credential.credential_json,
|
||||
user_id=credential.user_id,
|
||||
public_doc=credential.public_doc,
|
||||
is_admin=credential.is_admin,
|
||||
time_created=credential.time_created,
|
||||
time_updated=credential.time_updated,
|
||||
)
|
||||
|
@ -321,7 +321,8 @@ const Main = () => {
|
||||
| Credential<GoogleDriveCredentialJson>
|
||||
| undefined = credentialsData.find(
|
||||
(credential) =>
|
||||
credential.credential_json?.google_drive_tokens && credential.public_doc
|
||||
credential.credential_json?.google_drive_tokens &&
|
||||
credential.user_id === null
|
||||
);
|
||||
const googleDriveServiceAccountCredential:
|
||||
| Credential<GoogleDriveServiceAccountCredentialJson>
|
||||
|
@ -152,7 +152,7 @@ function Main() {
|
||||
<ConnectorTitle
|
||||
ccPairName={connectorIndexingStatus.name}
|
||||
connector={connectorIndexingStatus.connector}
|
||||
isPublic={connectorIndexingStatus.credential.public_doc}
|
||||
isPublic={connectorIndexingStatus.public_doc}
|
||||
owner={connectorIndexingStatus.owner}
|
||||
/>
|
||||
),
|
||||
|
@ -20,8 +20,9 @@ export const GoogleDriveCard = ({
|
||||
const existingCredential: Credential<GoogleDriveCredentialJson> | undefined =
|
||||
userCredentials?.find(
|
||||
(credential) =>
|
||||
// user_id is set => credential is not a public credential
|
||||
credential.credential_json?.google_drive_tokens !== undefined &&
|
||||
!credential.public_doc
|
||||
credential.user_id !== null
|
||||
);
|
||||
|
||||
const credentialIsLinked =
|
||||
|
@ -57,7 +57,7 @@ export function CredentialForm<T extends Yup.AnyObject>({
|
||||
formikHelpers.setSubmitting(true);
|
||||
submitCredential<T>({
|
||||
credential_json: values,
|
||||
public_doc: true,
|
||||
is_admin: true,
|
||||
}).then(({ message, isSuccess }) => {
|
||||
setPopup({ message, type: isSuccess ? "success" : "error" });
|
||||
formikHelpers.setSubmitting(false);
|
||||
|
@ -139,7 +139,7 @@ export interface ConnectorIndexingStatus<
|
||||
// CREDENTIALS
|
||||
export interface CredentialBase<T> {
|
||||
credential_json: T;
|
||||
public_doc: boolean;
|
||||
is_admin: boolean;
|
||||
}
|
||||
|
||||
export interface Credential<T> extends CredentialBase<T> {
|
||||
|
Loading…
x
Reference in New Issue
Block a user