Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-09 20:39:29 +02:00)
Integration tests (#2256)
* initial commit
* almost done
* finished 3 tests
* minor refactor
* built out initial permission tests
* reworked test_deletion
* removed logging
* all original tests have been converted
* renamed user_groups to user_group
* mypy
* added test for doc set permissions
* unified naming for manager methods
* Refactored models and added new deletion test
* minor additions
* better logging + fixed input variables
* commented out failed tests
* Added readme
* readme update
* Added auth to IT: set auth_type to basic and require_email_verification to false
* Update run-it.yml
* used verify and added to readme
* added api key manager
This commit is contained in: parent 634de83d72, commit 8d443ada5b
2
.github/workflows/run-it.yml
vendored
@@ -92,6 +92,8 @@ jobs:
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
          AUTH_TYPE=basic \
          REQUIRE_EMAIL_VERIFICATION=false \
          IMAGE_TAG=it \
          docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
        id: start_docker
@@ -67,23 +67,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()


def validate_curator_request(groups: list | None, is_public: bool) -> None:
    if is_public:
        detail = "Curators cannot create public objects"
        logger.error(detail)
        raise HTTPException(
            status_code=401,
            detail=detail,
        )
    if not groups:
        detail = "Curators must specify 1+ groups"
        logger.error(detail)
        raise HTTPException(
            status_code=401,
            detail=detail,
        )


def is_user_admin(user: User | None) -> bool:
    if AUTH_TYPE == AuthType.DISABLED:
        return True
@ -334,9 +334,13 @@ def add_credential_to_connector(
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
if credential is None:
|
||||
error_msg = (
|
||||
f"Credential {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Credential does not exist or does not belong to user",
|
||||
detail=error_msg,
|
||||
)
|
||||
|
||||
existing_association = (
|
||||
@ -350,7 +354,7 @@ def add_credential_to_connector(
|
||||
if existing_association is not None:
|
||||
return StatusResponse(
|
||||
success=False,
|
||||
message=f"Connector already has Credential {credential_id}",
|
||||
message=f"Connector {connector_id} already has Credential {credential_id}",
|
||||
data=connector_id,
|
||||
)
|
||||
|
||||
@ -374,8 +378,8 @@ def add_credential_to_connector(
|
||||
db_session.commit()
|
||||
|
||||
return StatusResponse(
|
||||
success=False,
|
||||
message=f"Connector already has Credential {credential_id}",
|
||||
success=True,
|
||||
message=f"Creating new association between Connector {connector_id} and Credential {credential_id}",
|
||||
data=association.id,
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -21,12 +20,13 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import get_index_attempts_for_connector
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.documents.models import CCPairFullInfo
|
||||
from danswer.server.documents.models import CCStatusUpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -84,10 +84,6 @@ def get_cc_pair_full_info(
|
||||
)
|
||||
|
||||
|
||||
class CCStatusUpdateRequest(BaseModel):
|
||||
status: ConnectorCredentialPairStatus
|
||||
|
||||
|
||||
@router.put("/admin/cc-pair/{cc_pair_id}/status")
|
||||
def update_cc_pair_status(
|
||||
cc_pair_id: int,
|
||||
@ -157,11 +153,12 @@ def associate_credential_to_connector(
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[int]:
|
||||
if user and user.role != UserRole.ADMIN and metadata.is_public:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Public connections cannot be created by non-admin users",
|
||||
)
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=metadata.groups,
|
||||
object_is_public=metadata.is_public,
|
||||
)
|
||||
|
||||
try:
|
||||
response = add_credential_to_connector(
|
||||
|
@ -75,7 +75,6 @@ from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.server.documents.models import AuthStatus
|
||||
from danswer.server.documents.models import AuthUrl
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||
from danswer.server.documents.models import ConnectorSnapshot
|
||||
@ -93,6 +92,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
from danswer.server.documents.models import RunConnectorRequest
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -514,35 +514,6 @@ def _validate_connector_allowed(source: DocumentSource) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _check_connector_permissions(
|
||||
connector_data: ConnectorUpdateRequest, user: User | None
|
||||
) -> ConnectorBase:
|
||||
"""
|
||||
This is not a proper permission check, but this should prevent curators creating bad situations
|
||||
until a long-term solution is implemented (Replacing CC pairs/Connectors with Connections)
|
||||
"""
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
if connector_data.is_public:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Public connectors can only be created by admins",
|
||||
)
|
||||
if not connector_data.groups:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connectors created by curators must have groups",
|
||||
)
|
||||
return ConnectorBase(
|
||||
name=connector_data.name,
|
||||
source=connector_data.source,
|
||||
input_type=connector_data.input_type,
|
||||
connector_specific_config=connector_data.connector_specific_config,
|
||||
refresh_freq=connector_data.refresh_freq,
|
||||
prune_freq=connector_data.prune_freq,
|
||||
indexing_start=connector_data.indexing_start,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector")
|
||||
def create_connector_from_model(
|
||||
connector_data: ConnectorUpdateRequest,
|
||||
@ -551,13 +522,19 @@ def create_connector_from_model(
|
||||
) -> ObjectCreationIdResponse:
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
connector_base = _check_connector_permissions(connector_data, user)
|
||||
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.is_public,
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
return create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_base,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error creating connector: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@ -608,12 +585,18 @@ def create_connector_with_mock_credential(
|
||||
def update_connector_from_model(
|
||||
connector_id: int,
|
||||
connector_data: ConnectorUpdateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorSnapshot | StatusResponse[int]:
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
connector_base = _check_connector_permissions(connector_data, user)
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
object_is_public=connector_data.is_public,
|
||||
)
|
||||
connector_base = connector_data.to_connector_base()
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@ -643,7 +626,7 @@ def update_connector_from_model(
|
||||
@router.delete("/admin/connector/{connector_id}", response_model=StatusResponse[int])
|
||||
def delete_connector_by_id(
|
||||
connector_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[int]:
|
||||
try:
|
||||
|
@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import validate_curator_request
|
||||
from danswer.db.credentials import alter_credential
|
||||
from danswer.db.credentials import create_credential
|
||||
from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
|
||||
@ -20,7 +19,6 @@ from danswer.db.credentials import update_credential
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import DocumentSource
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialDataUpdateRequest
|
||||
from danswer.server.documents.models import CredentialSnapshot
|
||||
@ -28,6 +26,7 @@ from danswer.server.documents.models import CredentialSwapRequest
|
||||
from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -80,7 +79,7 @@ def get_cc_source_full_info(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/credentials/{id}")
|
||||
@router.get("/credential/{id}")
|
||||
def list_credentials_by_id(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@ -105,7 +104,7 @@ def delete_credential_by_id_admin(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/admin/credentials/swap")
|
||||
@router.put("/admin/credential/swap")
|
||||
def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
@ -131,14 +130,12 @@ def create_credential_from_model(
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ObjectCreationIdResponse:
|
||||
if (
|
||||
user
|
||||
and user.role != UserRole.ADMIN
|
||||
and not _ignore_credential_permissions(credential_info.source)
|
||||
):
|
||||
validate_curator_request(
|
||||
groups=credential_info.groups,
|
||||
is_public=credential_info.curator_public,
|
||||
if not _ignore_credential_permissions(credential_info.source):
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=credential_info.groups,
|
||||
object_is_public=credential_info.curator_public,
|
||||
)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
@ -179,7 +176,7 @@ def get_credential_by_id(
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
|
||||
|
||||
@router.put("/admin/credentials/{credential_id}")
|
||||
@router.put("/admin/credential/{credential_id}")
|
||||
def update_credential_data(
|
||||
credential_id: int,
|
||||
credential_update: CredentialDataUpdateRequest,
|
||||
|
@ -48,9 +48,12 @@ class ConnectorBase(BaseModel):
|
||||
|
||||
|
||||
class ConnectorUpdateRequest(ConnectorBase):
|
||||
is_public: bool | None = None
|
||||
is_public: bool = True
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
def to_connector_base(self) -> ConnectorBase:
|
||||
return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"}))
|
||||
|
||||
|
||||
class ConnectorSnapshot(ConnectorBase):
|
||||
id: int
|
||||
@ -103,11 +106,6 @@ class CredentialSnapshot(CredentialBase):
|
||||
user_id: UUID | None
|
||||
time_created: datetime
|
||||
time_updated: datetime
|
||||
name: str | None
|
||||
source: DocumentSource
|
||||
credential_json: dict[str, Any]
|
||||
admin_public: bool
|
||||
curator_public: bool
|
||||
|
||||
@classmethod
|
||||
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
|
||||
@ -261,6 +259,10 @@ class ConnectorCredentialPairMetadata(BaseModel):
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class CCStatusUpdateRequest(BaseModel):
|
||||
status: ConnectorCredentialPairStatus
|
||||
|
||||
|
||||
class ConnectorCredentialPairDescriptor(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
|
@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import validate_curator_request
|
||||
from danswer.db.document_set import check_document_sets_are_public
|
||||
from danswer.db.document_set import fetch_all_document_sets_for_user
|
||||
from danswer.db.document_set import insert_document_set
|
||||
@ -14,12 +13,12 @@ from danswer.db.document_set import mark_document_set_as_to_be_deleted
|
||||
from danswer.db.document_set import update_document_set
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.features.document_set.models import CheckDocSetPublicRequest
|
||||
from danswer.server.features.document_set.models import CheckDocSetPublicResponse
|
||||
from danswer.server.features.document_set.models import DocumentSet
|
||||
from danswer.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
@ -31,11 +30,12 @@ def create_document_set(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> int:
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
validate_curator_request(
|
||||
groups=document_set_creation_request.groups,
|
||||
is_public=document_set_creation_request.is_public,
|
||||
)
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=document_set_creation_request.groups,
|
||||
object_is_public=document_set_creation_request.is_public,
|
||||
)
|
||||
try:
|
||||
document_set_db_model, _ = insert_document_set(
|
||||
document_set_creation_request=document_set_creation_request,
|
||||
@ -53,11 +53,12 @@ def patch_document_set(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
validate_curator_request(
|
||||
groups=document_set_update_request.groups,
|
||||
is_public=document_set_update_request.is_public,
|
||||
)
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
object_is_public=document_set_update_request.is_public,
|
||||
)
|
||||
try:
|
||||
update_document_set(
|
||||
document_set_update_request=document_set_update_request,
|
||||
|
@ -69,7 +69,7 @@ def set_user_role(
|
||||
|
||||
if user_role_update_request.new_role == UserRole.CURATOR:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
status_code=402,
|
||||
detail="Curator role must be set via the User Group Menu",
|
||||
)
|
||||
|
||||
@ -78,7 +78,7 @@ def set_user_role(
|
||||
|
||||
if current_user.id == user_to_update.id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
status_code=402,
|
||||
detail="An admin cannot demote themselves from admin role!",
|
||||
)
|
||||
|
||||
|
@@ -2,6 +2,7 @@ from collections.abc import Sequence
from operator import and_
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
@@ -30,6 +31,50 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()


def validate_user_creation_permissions(
    db_session: Session,
    user: User | None,
    target_group_ids: list[int] | None,
    object_is_public: bool | None,
) -> None:
    """
    All admin actions are allowed.
    Prevents non-admins from creating/editing:
    - public objects
    - objects with no groups
    - objects that belong to a group they don't curate
    """
    if not user or user.role == UserRole.ADMIN:
        return

    if object_is_public:
        detail = "User does not have permission to create public credentials"
        logger.error(detail)
        raise HTTPException(
            status_code=402,
            detail=detail,
        )
    if not target_group_ids:
        detail = "Curators must specify 1+ groups"
        logger.error(detail)
        raise HTTPException(
            status_code=402,
            detail=detail,
        )
    user_curated_groups = fetch_user_groups_for_user(
        db_session=db_session, user_id=user.id, only_curator_groups=True
    )
    user_curated_group_ids = set([group.id for group in user_curated_groups])
    target_group_ids_set = set(target_group_ids)
    if not target_group_ids_set.issubset(user_curated_group_ids):
        detail = "Curators cannot control groups they don't curate"
        logger.error(detail)
        raise HTTPException(
            status_code=402,
            detail=detail,
        )


def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
    stmt = select(UserGroup).where(UserGroup.id == user_group_id)
    return db_session.scalar(stmt)
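As an aside (not part of the diff itself), the routes updated in this PR all invoke this helper the same way. A minimal sketch based on the `create_document_set` change shown earlier in this diff; the route decorator path is assumed for illustration:

```python
# Illustrative sketch only -- mirrors the create_document_set change earlier in this diff.
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session

from danswer.auth.users import current_curator_or_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from ee.danswer.db.user_group import validate_user_creation_permissions

router = APIRouter(prefix="/manage")


@router.post("/admin/document-set")  # route path assumed for illustration
def create_document_set(
    document_set_creation_request: DocumentSetCreationRequest,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> int:
    # admins pass straight through; curators get a 402 if the object is public,
    # has no groups, or targets a group they don't curate
    validate_user_creation_permissions(
        db_session=db_session,
        user=user,
        target_group_ids=document_set_creation_request.groups,
        object_is_public=document_set_creation_request.is_public,
    )
    ...  # insert the document set, as in the actual endpoint
```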
@ -9,6 +9,7 @@ from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.user_group import fetch_user_groups
|
||||
from ee.danswer.db.user_group import fetch_user_groups_for_user
|
||||
from ee.danswer.db.user_group import insert_user_group
|
||||
@ -20,6 +21,8 @@ from ee.danswer.server.user_group.models import UserGroup
|
||||
from ee.danswer.server.user_group.models import UserGroupCreate
|
||||
from ee.danswer.server.user_group.models import UserGroupUpdate
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
|
||||
|
||||
@ -90,6 +93,7 @@ def set_user_curator(
|
||||
set_curator_request=set_curator_request,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error setting user curator: {e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
|
70
backend/tests/integration/README.md
Normal file
@@ -0,0 +1,70 @@
# Integration Tests

## General Testing Overview
The integration tests are designed with a "manager" class and a "test" class for each type of object being manipulated (e.g., user, persona, credential):
- **Manager Class**: Contains methods for each type of API call. Responsible for creating, deleting, and verifying the existence of an entity.
- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object.

The idea is that each test uses the manager class to create (`.create()`) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and check that the "test_" object is in the expected state by using the manager class's `.verify()` function.
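To make the pattern concrete, here is a minimal sketch of a test written in this style. It is illustrative only (not part of the commit): it uses the `CredentialManager` and `TestUser` added in this PR, and the creation of `admin_user` is elided since the user manager is not shown in this diff.

```python
# Illustrative sketch only -- not part of this commit.
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.test_models import TestUser


def test_credential_deletion(admin_user: TestUser) -> None:
    # .create() issues the API call and returns a "test_" object holding the expected state
    test_credential = CredentialManager.create(
        name="example", user_performing_action=admin_user
    )

    # perform the operation under test
    CredentialManager.delete(test_credential, user_performing_action=admin_user)

    # .verify() checks the actual API state against the expected state
    CredentialManager.verify(
        test_credential, verify_deleted=True, user_performing_action=admin_user
    )
```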
## Instructions for Running Integration Tests Locally
1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080.
    a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory.
2. Navigate to `danswer/backend`.
3. Run the following command in the terminal:
```sh
pytest -s tests/integration/tests/
```
or to run all tests in a file:
```sh
pytest -s tests/integration/tests/path_to/test_file.py
```
or to run a single test:
```sh
pytest -s tests/integration/tests/path_to/test_file.py::test_function_name
```

## Guidelines for Writing Integration Tests
- As authentication is currently required for all tests, each test should start by creating a user.
- Each test should ideally focus on a single API flow.
- The test writer should consider failure cases and edge cases for the flow and write tests to check for them.
- Every step of the test should be commented, describing what is being done and what the expected behavior is.
- A summary of the test should be given at the top of the test function as well!
- When writing new tests, manager classes, manager functions, and test classes, try to copy the style of the ones that have already been written.
- Be careful of scope creep!
- There is no need to verify after every single API call so long as the case you would be verifying is covered elsewhere (ideally in a test focused on covering that case).
- An example of this: creating an admin user is done at the beginning of nearly every test, but we only need to verify that the user is actually an admin in the test focused on checking admin permissions. For every other test, we can just create the admin user and assume that the permissions are working as expected.

## Current Testing Limitations
### Test coverage
- The tests are probably not as high-coverage as they could be.
- The "connector" tests in particular are super bare bones because we will be reworking connector/cc_pair sometime soon.
- The Global Curator role is not thoroughly tested.
- The no-auth configuration is not tested at all.
### Failure checking
- While we test expected auth failures, we only check that they failed at all.
- We don't check that the return codes are what we expect.
- This means that a test could be failing for a different reason than expected.
- We should ensure that the proper codes are being returned for each failure case.
- We should also query the db after each failure to ensure that the db is in the expected state.
### Scope/focus
- The tests may be scoped sub-optimally.
- The scoping of each test may be overlapping.

## Current Testing Coverage
The current testing coverage should be checked by reading the comments at the top of each test file.


## TODO: Testing Coverage
- Persona permissions testing
- Read-only (and/or basic) user permissions
- Ensuring proper permission enforcement using the chat/doc_search endpoints
- No auth

## Ideas for integration testing design
### Combine the "test" and "manager" classes
This could make test writing a bit cleaner by preventing test writers from having to pass objects into functions that those objects have a 1:1 relationship with.

### Rework VespaClient
Right now, it's used as a fixture and has to be passed around between manager classes.
It could just be built where it's used.
|
@ -1,114 +0,0 @@
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
|
||||
|
||||
class ConnectorCreationDetails(BaseModel):
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
cc_pair_id: int
|
||||
|
||||
|
||||
class ConnectorClient:
|
||||
@staticmethod
|
||||
def create_connector(
|
||||
name_prefix: str = "test_connector", credential_id: int | None = None
|
||||
) -> ConnectorCreationDetails:
|
||||
unique_id = uuid.uuid4()
|
||||
|
||||
connector_name = f"{name_prefix}_{unique_id}"
|
||||
connector_data = {
|
||||
"name": connector_name,
|
||||
"source": DocumentSource.NOT_APPLICABLE,
|
||||
"input_type": "load_state",
|
||||
"connector_specific_config": {},
|
||||
"refresh_freq": 60,
|
||||
"disabled": True,
|
||||
}
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector",
|
||||
json=connector_data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
connector_id = response.json()["id"]
|
||||
|
||||
# associate the credential with the connector
|
||||
if not credential_id:
|
||||
print("ID not specified, creating new credential")
|
||||
# Create a new credential
|
||||
credential_data = {
|
||||
"credential_json": {},
|
||||
"admin_public": True,
|
||||
"source": DocumentSource.NOT_APPLICABLE,
|
||||
}
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/credential",
|
||||
json=credential_data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
credential_id = cast(int, response.json()["id"])
|
||||
|
||||
cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True}
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}",
|
||||
json=cc_pair_metadata,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# fetch the connector credential pair id using the indexing status API
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status"
|
||||
)
|
||||
response.raise_for_status()
|
||||
indexing_statuses = response.json()
|
||||
|
||||
cc_pair_id = None
|
||||
for status in indexing_statuses:
|
||||
if (
|
||||
status["connector"]["id"] == connector_id
|
||||
and status["credential"]["id"] == credential_id
|
||||
):
|
||||
cc_pair_id = status["cc_pair_id"]
|
||||
break
|
||||
|
||||
if cc_pair_id is None:
|
||||
raise ValueError("Could not find the connector credential pair id")
|
||||
|
||||
print(
|
||||
f"Created connector with connector_id: {connector_id}, credential_id: {credential_id}, cc_pair_id: {cc_pair_id}"
|
||||
)
|
||||
return ConnectorCreationDetails(
|
||||
connector_id=int(connector_id),
|
||||
credential_id=int(credential_id),
|
||||
cc_pair_id=int(cc_pair_id),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_connector_status(
|
||||
cc_pair_id: int, status: ConnectorCredentialPairStatus
|
||||
) -> None:
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/status",
|
||||
json={"status": status},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete_connector(connector_id: int, credential_id: int) -> None:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/deletion-attempt",
|
||||
json={"connector_id": connector_id, "credential_id": credential_id},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_connectors() -> list[dict]:
|
||||
response = requests.get(f"{API_SERVER_URL}/manage/connector")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
@@ -5,3 +5,7 @@ API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 30

GENERAL_HEADERS = {"Content-Type": "application/json"}

NUM_DOCS = 5
|
@ -1,30 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.server.features.document_set.models import DocumentSet
|
||||
from danswer.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
|
||||
|
||||
class DocumentSetClient:
|
||||
@staticmethod
|
||||
def create_document_set(
|
||||
doc_set_creation_request: DocumentSetCreationRequest,
|
||||
) -> int:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_creation_request.model_dump(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cast(int, response.json())
|
||||
|
||||
@staticmethod
|
||||
def fetch_document_sets() -> list[DocumentSet]:
|
||||
response = requests.get(f"{API_SERVER_URL}/manage/document-set")
|
||||
response.raise_for_status()
|
||||
|
||||
document_sets = [
|
||||
DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json()
|
||||
]
|
||||
return document_sets
|
@ -1,62 +1,88 @@
|
||||
import os
|
||||
from typing import cast
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestLLMProvider
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class LLMProvider(BaseModel):
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
is_default: bool = True
|
||||
class LLMProviderManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
provider: str | None = None,
|
||||
api_key: str | None = None,
|
||||
default_model_name: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
groups: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestLLMProvider:
|
||||
print("Seeding LLM Providers...")
|
||||
|
||||
# only populated after creation
|
||||
_provider_id: int | None = PrivateAttr()
|
||||
|
||||
def create(self) -> int:
|
||||
llm_provider = LLMProviderUpsertRequest(
|
||||
name=self.provider,
|
||||
provider=self.provider,
|
||||
default_model_name=self.default_model_name,
|
||||
api_key=self.api_key,
|
||||
api_base=self.api_base,
|
||||
api_version=self.api_version,
|
||||
name=name or f"test-provider-{uuid4()}",
|
||||
provider=provider or "openai",
|
||||
default_model_name=default_model_name or "gpt-4o-mini",
|
||||
api_key=api_key or os.environ["OPENAI_API_KEY"],
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
custom_config=None,
|
||||
fast_default_model_name=None,
|
||||
is_public=True,
|
||||
groups=[],
|
||||
fast_default_model_name=default_model_name or "gpt-4o-mini",
|
||||
is_public=is_public or True,
|
||||
groups=groups or [],
|
||||
display_model_names=None,
|
||||
model_names=None,
|
||||
)
|
||||
|
||||
response = requests.put(
|
||||
llm_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
json=llm_provider.dict(),
|
||||
json=llm_provider.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
llm_response.raise_for_status()
|
||||
response_data = llm_response.json()
|
||||
result_llm = TestLLMProvider(
|
||||
id=response_data["id"],
|
||||
name=response_data["name"],
|
||||
provider=response_data["provider"],
|
||||
api_key=response_data["api_key"],
|
||||
default_model_name=response_data["default_model_name"],
|
||||
is_public=response_data["is_public"],
|
||||
groups=response_data["groups"],
|
||||
api_base=response_data["api_base"],
|
||||
api_version=response_data["api_version"],
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
self._provider_id = cast(int, response.json()["id"])
|
||||
return self._provider_id
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
set_default_response.raise_for_status()
|
||||
|
||||
def delete(self) -> None:
|
||||
return result_llm
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
llm_provider: TestLLMProvider,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> bool:
|
||||
if not llm_provider.id:
|
||||
raise ValueError("LLM Provider ID is required to delete a provider")
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{self._provider_id}"
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def seed_default_openai_provider() -> LLMProvider:
|
||||
llm = LLMProvider(
|
||||
provider="openai",
|
||||
default_model_name="gpt-4o-mini",
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
)
|
||||
llm.create()
|
||||
return llm
|
||||
return True
|
||||
|
92
backend/tests/integration/common_utils/managers/api_key.py
Normal file
@ -0,0 +1,92 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.db.models import UserRole
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class APIKeyManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
api_key_role: UserRole = UserRole.ADMIN,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestAPIKey:
|
||||
name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"
|
||||
api_key_request = APIKeyArgs(
|
||||
name=name,
|
||||
role=api_key_role,
|
||||
)
|
||||
api_key_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
json=api_key_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
api_key = api_key_response.json()
|
||||
result_api_key = TestAPIKey(
|
||||
api_key_id=api_key["api_key_id"],
|
||||
api_key_display=api_key["api_key_display"],
|
||||
api_key=api_key["api_key"],
|
||||
api_key_name=name,
|
||||
api_key_role=api_key_role,
|
||||
user_id=api_key["user_id"],
|
||||
headers=GENERAL_HEADERS,
|
||||
)
|
||||
result_api_key.headers["Authorization"] = f"Bearer {result_api_key.api_key}"
|
||||
return result_api_key
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
api_key: TestAPIKey,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
api_key_response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestAPIKey]:
|
||||
api_key_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
api_key_response.raise_for_status()
|
||||
return [TestAPIKey(**api_key) for api_key in api_key_response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
api_key: TestAPIKey,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
retrieved_keys = APIKeyManager.get_all(
|
||||
user_performing_action=user_performing_action
|
||||
)
|
||||
for key in retrieved_keys:
|
||||
if key.api_key_id == api_key.api_key_id:
|
||||
if verify_deleted:
|
||||
raise ValueError("API Key found when it should have been deleted")
|
||||
if (
|
||||
key.api_key_name == api_key.api_key_name
|
||||
and key.api_key_role == api_key.api_key_role
|
||||
):
|
||||
return
|
||||
|
||||
if not verify_deleted:
|
||||
raise Exception("API Key not found")
|
202
backend/tests/integration/common_utils/managers/cc_pair.py
Normal file
@ -0,0 +1,202 @@
|
||||
import time
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.server.documents.models import ConnectorIndexingStatus
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.test_models import TestCCPair
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
def _cc_pair_creator(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
name: str | None = None,
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCCPair:
|
||||
name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}"
|
||||
|
||||
request = {
|
||||
"name": name,
|
||||
"is_public": is_public,
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}",
|
||||
json=request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return TestCCPair(
|
||||
id=response.json()["data"],
|
||||
name=name,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
is_public=is_public,
|
||||
groups=groups or [],
|
||||
)
|
||||
|
||||
|
||||
class CCPairManager:
|
||||
@staticmethod
|
||||
def create_from_scratch(
|
||||
name: str | None = None,
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCCPair:
|
||||
connector = ConnectorManager.create(
|
||||
name=name,
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config,
|
||||
is_public=is_public,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
credential_json=credential_json,
|
||||
name=name,
|
||||
source=source,
|
||||
curator_public=is_public,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
return _cc_pair_creator(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
name=name,
|
||||
is_public=is_public,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
name: str | None = None,
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCCPair:
|
||||
return _cc_pair_creator(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=name,
|
||||
is_public=is_public,
|
||||
groups=groups,
|
||||
user_performing_action=user_performing_action,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def pause_cc_pair(
|
||||
cc_pair: TestCCPair,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
result = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status",
|
||||
json={"status": "PAUSED"},
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
cc_pair: TestCCPair,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
cc_pair_identifier = ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/deletion-attempt",
|
||||
json=cc_pair_identifier.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/indexing-status",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [ConnectorIndexingStatus(**cc_pair) for cc_pair in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
cc_pair: TestCCPair,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
all_cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
for retrieved_cc_pair in all_cc_pairs:
|
||||
if retrieved_cc_pair.cc_pair_id == cc_pair.id:
|
||||
if verify_deleted:
|
||||
# We assume that this check will be performed after the deletion is
|
||||
# already waited for
|
||||
raise ValueError(
|
||||
f"CC pair {cc_pair.id} found but should be deleted"
|
||||
)
|
||||
if (
|
||||
retrieved_cc_pair.name == cc_pair.name
|
||||
and retrieved_cc_pair.connector.id == cc_pair.connector_id
|
||||
and retrieved_cc_pair.credential.id == cc_pair.credential_id
|
||||
and retrieved_cc_pair.public_doc == cc_pair.is_public
|
||||
and set(retrieved_cc_pair.groups) == set(cc_pair.groups)
|
||||
):
|
||||
return
|
||||
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"CC pair {cc_pair.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def wait_for_deletion_completion(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
while True:
|
||||
cc_pairs = CCPairManager.get_all(user_performing_action)
|
||||
if all(
|
||||
cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING
|
||||
for cc_pair in cc_pairs
|
||||
):
|
||||
return
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError(
|
||||
f"CC pairs deletion was not completed within the {MAX_DELAY} seconds"
|
||||
)
|
||||
else:
|
||||
print("Some CC pairs are still being deleted, waiting...")
|
||||
time.sleep(2)
|
124
backend/tests/integration/common_utils/managers/connector.py
Normal file
@ -0,0 +1,124 @@
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.server.documents.models import ConnectorUpdateRequest
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestConnector
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class ConnectorManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
input_type: InputType = InputType.LOAD_STATE,
|
||||
connector_specific_config: dict[str, Any] | None = None,
|
||||
is_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestConnector:
|
||||
name = f"{name}-connector" if name else f"test-connector-{uuid4()}"
|
||||
|
||||
connector_update_request = ConnectorUpdateRequest(
|
||||
name=name,
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config or {},
|
||||
is_public=is_public,
|
||||
groups=groups or [],
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector",
|
||||
json=connector_update_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_data = response.json()
|
||||
return TestConnector(
|
||||
id=response_data.get("id"),
|
||||
name=name,
|
||||
source=source,
|
||||
input_type=input_type,
|
||||
connector_specific_config=connector_specific_config or {},
|
||||
groups=groups,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
connector: TestConnector,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
json=connector.model_dump(exclude={"id"}),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
connector: TestConnector,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestConnector]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
TestConnector(
|
||||
id=conn.get("id"),
|
||||
name=conn.get("name", ""),
|
||||
source=conn.get("source", DocumentSource.FILE),
|
||||
input_type=conn.get("input_type", InputType.LOAD_STATE),
|
||||
connector_specific_config=conn.get("connector_specific_config", {}),
|
||||
)
|
||||
for conn in response.json()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
connector_id: int, user_performing_action: TestUser | None = None
|
||||
) -> TestConnector:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/connector/{connector_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
conn = response.json()
|
||||
return TestConnector(
|
||||
id=conn.get("id"),
|
||||
name=conn.get("name", ""),
|
||||
source=conn.get("source", DocumentSource.FILE),
|
||||
input_type=conn.get("input_type", InputType.LOAD_STATE),
|
||||
connector_specific_config=conn.get("connector_specific_config", {}),
|
||||
)
|
129
backend/tests/integration/common_utils/managers/credential.py
Normal file
@ -0,0 +1,129 @@
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.server.documents.models import CredentialSnapshot
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestCredential
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class CredentialManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
credential_json: dict[str, Any] | None = None,
|
||||
admin_public: bool = True,
|
||||
name: str | None = None,
|
||||
source: DocumentSource = DocumentSource.FILE,
|
||||
curator_public: bool = True,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestCredential:
|
||||
name = f"{name}-credential" if name else f"test-credential-{uuid4()}"
|
||||
|
||||
credential_request = {
|
||||
"name": name,
|
||||
"credential_json": credential_json or {},
|
||||
"admin_public": admin_public,
|
||||
"source": source,
|
||||
"curator_public": curator_public,
|
||||
"groups": groups or [],
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/credential",
|
||||
json=credential_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
return TestCredential(
|
||||
id=response.json()["id"],
|
||||
name=name,
|
||||
credential_json=credential_json or {},
|
||||
admin_public=admin_public,
|
||||
source=source,
|
||||
curator_public=curator_public,
|
||||
groups=groups or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
credential: TestCredential,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
request = credential.model_dump(include={"name", "credential_json"})
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}",
|
||||
json=request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
credential: TestCredential,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
response = requests.delete(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get(
|
||||
credential_id: int, user_performing_action: TestUser | None = None
|
||||
) -> CredentialSnapshot:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/credential/{credential_id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return CredentialSnapshot(**response.json())
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[CredentialSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/credential",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [CredentialSnapshot(**cred) for cred in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
credential: TestCredential,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
all_credentials = CredentialManager.get_all(user_performing_action)
|
||||
for fetched_credential in all_credentials:
|
||||
if credential.id == fetched_credential.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
f"Credential {credential.id} found but should be deleted"
|
||||
)
|
||||
if (
|
||||
credential.name == fetched_credential.name
|
||||
and credential.admin_public == fetched_credential.admin_public
|
||||
and credential.source == fetched_credential.source
|
||||
and credential.curator_public == fetched_credential.curator_public
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"Credential {credential.id} not found")
|
153
backend/tests/integration/common_utils/managers/document.py
Normal file
@ -0,0 +1,153 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import TestAPIKey
|
||||
from tests.integration.common_utils.managers.cc_pair import TestCCPair
|
||||
from tests.integration.common_utils.test_models import SimpleTestDocument
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
|
||||
|
||||
def _verify_document_permissions(
|
||||
retrieved_doc: dict,
|
||||
cc_pair: TestCCPair,
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: TestUser | None = None,
|
||||
) -> None:
|
||||
acl_keys = set(retrieved_doc["access_control_list"].keys())
|
||||
print(f"ACL keys: {acl_keys}")
|
||||
if cc_pair.is_public:
|
||||
if "PUBLIC" not in acl_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} is public but"
|
||||
" does not have the PUBLIC ACL key"
|
||||
)
|
||||
|
||||
if doc_creating_user is not None:
|
||||
if f"user_id:{doc_creating_user.id}" not in acl_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} was created by user"
|
||||
f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key"
|
||||
)
|
||||
|
||||
if group_names is not None:
|
||||
expected_group_keys = {f"group:{group_name}" for group_name in group_names}
|
||||
found_group_keys = {key for key in acl_keys if key.startswith("group:")}
|
||||
if found_group_keys != expected_group_keys:
|
||||
raise ValueError(
|
||||
f"Document {retrieved_doc['document_id']} has incorrect group ACL keys. Found: {found_group_keys}, \n"
|
||||
f"Expected: {expected_group_keys}"
|
||||
)
|
||||
|
||||
if doc_set_names is not None:
|
||||
found_doc_set_names = set(retrieved_doc.get("document_sets", {}).keys())
|
||||
if found_doc_set_names != set(doc_set_names):
|
||||
raise ValueError(
|
||||
f"Document set names mismatch. \nFound: {found_doc_set_names}, \n"
|
||||
f"Expected: {set(doc_set_names)}"
|
||||
)
|
||||
|
||||
|
||||
def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict:
|
||||
return {
|
||||
"document": {
|
||||
"id": document_id,
|
||||
"sections": [
|
||||
{
|
||||
"text": f"This is test document {document_id}",
|
||||
"link": f"{document_id}",
|
||||
}
|
||||
],
|
||||
"source": DocumentSource.NOT_APPLICABLE,
|
||||
# just for testing metadata
|
||||
"metadata": {"document_id": document_id},
|
||||
"semantic_identifier": f"Test Document {document_id}",
|
||||
"from_ingestion_api": True,
|
||||
},
|
||||
"cc_pair_id": cc_pair_id,
|
||||
}
|
||||
|
||||
|
||||
class DocumentManager:
|
||||
@staticmethod
|
||||
def seed_and_attach_docs(
|
||||
cc_pair: TestCCPair,
|
||||
num_docs: int = NUM_DOCS,
|
||||
document_ids: list[str] | None = None,
|
||||
api_key: TestAPIKey | None = None,
|
||||
) -> TestCCPair:
|
||||
# Use provided document_ids if available, otherwise generate random UUIDs
|
||||
if document_ids is None:
|
||||
document_ids = [f"test-doc-{uuid4()}" for _ in range(num_docs)]
|
||||
else:
|
||||
num_docs = len(document_ids)
|
||||
# Create and ingest some documents
|
||||
documents: list[dict] = []
|
||||
for document_id in document_ids:
|
||||
document = _generate_dummy_document(document_id, cc_pair.id)
|
||||
documents.append(document)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/danswer-api/ingestion",
|
||||
json=document,
|
||||
headers=api_key.headers if api_key else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
print("Seeding completed successfully.")
|
||||
cc_pair.documents = [
|
||||
SimpleTestDocument(
|
||||
id=document["document"]["id"],
|
||||
content=document["document"]["sections"][0]["text"],
|
||||
)
|
||||
for document in documents
|
||||
]
|
||||
return cc_pair
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
vespa_client: TestVespaClient,
|
||||
cc_pair: TestCCPair,
|
||||
# If None, will not check doc sets or groups
|
||||
# If empty list, will check for empty doc sets or groups
|
||||
doc_set_names: list[str] | None = None,
|
||||
group_names: list[str] | None = None,
|
||||
doc_creating_user: TestUser | None = None,
|
||||
verify_deleted: bool = False,
|
||||
) -> None:
|
||||
doc_ids = [document.id for document in cc_pair.documents]
|
||||
retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]
|
||||
retrieved_docs = {
|
||||
doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict
|
||||
}
|
||||
# Left this here for debugging purposes.
|
||||
# import json
|
||||
# for doc in retrieved_docs.values():
|
||||
# printable_doc = doc.copy()
|
||||
# print(printable_doc.keys())
|
||||
# printable_doc.pop("embeddings")
|
||||
# printable_doc.pop("title_embedding")
|
||||
# print(json.dumps(printable_doc, indent=2))
|
||||
|
||||
for document in cc_pair.documents:
|
||||
retrieved_doc = retrieved_docs.get(document.id)
|
||||
if not retrieved_doc:
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"Document not found: {document.id}")
|
||||
continue
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
f"Document found when it should be deleted: {document.id}"
|
||||
)
|
||||
_verify_document_permissions(
|
||||
retrieved_doc,
|
||||
cc_pair,
|
||||
doc_set_names,
|
||||
group_names,
|
||||
doc_creating_user,
|
||||
)
|
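# A short sketch of the seed-and-verify flow this manager supports, mirroring the
# deletion tests later in this commit; `vespa_client` would come from the
# TestVespaClient fixture and all names here are illustrative.
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.vespa import TestVespaClient


def _example_seed_and_verify(vespa_client: TestVespaClient) -> None:
    admin_user = UserManager.create(name="doc_admin")
    api_key = APIKeyManager.create(user_performing_action=admin_user)
    cc_pair = CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
    cc_pair = DocumentManager.seed_and_attach_docs(
        cc_pair=cc_pair, num_docs=4, api_key=api_key
    )
    # freshly seeded docs should exist in Vespa with no doc sets or groups attached
    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair,
        doc_set_names=[],
        group_names=[],
        doc_creating_user=admin_user,
    )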
171
backend/tests/integration/common_utils/managers/document_set.py
Normal file
171
backend/tests/integration/common_utils/managers/document_set.py
Normal file
@ -0,0 +1,171 @@
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import TestDocumentSet
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class DocumentSetManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
is_public: bool = True,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestDocumentSet:
|
||||
if name is None:
|
||||
name = f"test_doc_set_{str(uuid4())}"
|
||||
|
||||
doc_set_creation_request = {
|
||||
"name": name,
|
||||
"description": description or name,
|
||||
"cc_pair_ids": cc_pair_ids or [],
|
||||
"is_public": is_public,
|
||||
"users": users or [],
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_creation_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return TestDocumentSet(
|
||||
id=int(response.json()),
|
||||
name=name,
|
||||
description=description or name,
|
||||
cc_pair_ids=cc_pair_ids or [],
|
||||
is_public=is_public,
|
||||
is_up_to_date=True,
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
document_set: TestDocumentSet,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> bool:
|
||||
doc_set_update_request = {
|
||||
"id": document_set.id,
|
||||
"description": document_set.description,
|
||||
"cc_pair_ids": document_set.cc_pair_ids,
|
||||
"is_public": document_set.is_public,
|
||||
"users": document_set.users,
|
||||
"groups": document_set.groups,
|
||||
}
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set",
|
||||
json=doc_set_update_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
document_set: TestDocumentSet,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestDocumentSet]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/document-set",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [
|
||||
TestDocumentSet(
|
||||
id=doc_set["id"],
|
||||
name=doc_set["name"],
|
||||
description=doc_set["description"],
|
||||
cc_pair_ids=[
|
||||
cc_pair["id"] for cc_pair in doc_set["cc_pair_descriptors"]
|
||||
],
|
||||
is_public=doc_set["is_public"],
|
||||
is_up_to_date=doc_set["is_up_to_date"],
|
||||
users=doc_set["users"],
|
||||
groups=doc_set["groups"],
|
||||
)
|
||||
for doc_set in response.json()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
document_sets_to_check: list[TestDocumentSet] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
# wait for document sets to be synced
|
||||
start = time.time()
|
||||
while True:
|
||||
doc_sets = DocumentSetManager.get_all(user_performing_action)
|
||||
if document_sets_to_check:
|
||||
check_ids = {doc_set.id for doc_set in document_sets_to_check}
|
||||
doc_set_ids = {doc_set.id for doc_set in doc_sets}
|
||||
if not check_ids.issubset(doc_set_ids):
|
||||
raise RuntimeError("Document set not found")
|
||||
doc_sets = [doc_set for doc_set in doc_sets if doc_set.id in check_ids]
|
||||
all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets)
|
||||
|
||||
if all_up_to_date:
|
||||
break
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError(
|
||||
f"Document sets were not synced within the {MAX_DELAY} seconds"
|
||||
)
|
||||
else:
|
||||
print("Document sets were not synced yet, waiting...")
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
document_set: TestDocumentSet,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
doc_sets = DocumentSetManager.get_all(user_performing_action)
|
||||
for doc_set in doc_sets:
|
||||
if doc_set.id == document_set.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
f"Document set {document_set.id} found but should have been deleted"
|
||||
)
|
||||
if (
|
||||
doc_set.name == document_set.name
|
||||
and set(doc_set.cc_pair_ids) == set(document_set.cc_pair_ids)
|
||||
and doc_set.is_public == document_set.is_public
|
||||
and set(doc_set.users) == set(document_set.users)
|
||||
and set(doc_set.groups) == set(document_set.groups)
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"Document set {document_set.id} not found")
|
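# A condensed lifecycle sketch for DocumentSetManager, matching how the tests in
# this commit use it; the cc_pair and user below are illustrative placeholders.
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import UserManager


def _example_document_set_lifecycle() -> None:
    admin_user = UserManager.create(name="doc_set_admin")
    cc_pair = CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
    doc_set = DocumentSetManager.create(
        name="Example Document Set",
        cc_pair_ids=[cc_pair.id],
        user_performing_action=admin_user,
    )
    # creation only schedules the sync; block until it reports is_up_to_date
    DocumentSetManager.wait_for_sync(
        document_sets_to_check=[doc_set],
        user_performing_action=admin_user,
    )
    # compare the local TestDocumentSet against GET /manage/document-set
    DocumentSetManager.verify(doc_set, user_performing_action=admin_user)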
206
backend/tests/integration/common_utils/managers/persona.py
Normal file
206
backend/tests/integration/common_utils/managers/persona.py
Normal file
@ -0,0 +1,206 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestPersona
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
llm_filter_extraction: bool = True,
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
|
||||
prompt_ids: list[int] | None = None,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
llm_model_provider_override: str | None = None,
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
|
||||
persona_creation_request = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"num_chunks": num_chunks,
|
||||
"llm_relevance_filter": llm_relevance_filter,
|
||||
"is_public": is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction,
|
||||
"recency_bias": recency_bias,
|
||||
"prompt_ids": prompt_ids or [],
|
||||
"document_set_ids": document_set_ids or [],
|
||||
"tool_ids": tool_ids or [],
|
||||
"llm_model_provider_override": llm_model_provider_override,
|
||||
"llm_model_version_override": llm_model_version_override,
|
||||
"users": users or [],
|
||||
"groups": groups or [],
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
json=persona_creation_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
persona_data = response.json()
|
||||
|
||||
return TestPersona(
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
prompt_ids=prompt_ids or [],
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
persona: TestPersona,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
llm_filter_extraction: bool | None = None,
|
||||
recency_bias: RecencyBiasSetting | None = None,
|
||||
prompt_ids: list[int] | None = None,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
llm_model_provider_override: str | None = None,
|
||||
llm_model_version_override: str | None = None,
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestPersona:
|
||||
persona_update_request = {
|
||||
"name": name or persona.name,
|
||||
"description": description or persona.description,
|
||||
"num_chunks": num_chunks or persona.num_chunks,
|
||||
"llm_relevance_filter": llm_relevance_filter
|
||||
or persona.llm_relevance_filter,
|
||||
"is_public": is_public or persona.is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction
|
||||
or persona.llm_filter_extraction,
|
||||
"recency_bias": recency_bias or persona.recency_bias,
|
||||
"prompt_ids": prompt_ids or persona.prompt_ids,
|
||||
"document_set_ids": document_set_ids or persona.document_set_ids,
|
||||
"tool_ids": tool_ids or persona.tool_ids,
|
||||
"llm_model_provider_override": llm_model_provider_override
|
||||
or persona.llm_model_provider_override,
|
||||
"llm_model_version_override": llm_model_version_override
|
||||
or persona.llm_model_version_override,
|
||||
"users": users or persona.users,
|
||||
"groups": groups or persona.groups,
|
||||
}
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=persona_update_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
updated_persona_data = response.json()
|
||||
|
||||
return TestPersona(
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
num_chunks=updated_persona_data["num_chunks"],
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=updated_persona_data["recency_bias"],
|
||||
prompt_ids=updated_persona_data["prompts"],
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
llm_model_version_override=updated_persona_data[
|
||||
"llm_model_version_override"
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[PersonaSnapshot]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/persona",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [PersonaSnapshot(**persona) for persona in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
test_persona: TestPersona,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> bool:
|
||||
all_personas = PersonaManager.get_all(user_performing_action)
|
||||
for persona in all_personas:
|
||||
if persona.id == test_persona.id:
|
||||
return (
|
||||
persona.name == test_persona.name
|
||||
and persona.description == test_persona.description
|
||||
and persona.num_chunks == test_persona.num_chunks
|
||||
and persona.llm_relevance_filter
|
||||
== test_persona.llm_relevance_filter
|
||||
and persona.is_public == test_persona.is_public
|
||||
and persona.llm_filter_extraction
|
||||
== test_persona.llm_filter_extraction
|
||||
and persona.llm_model_provider_override
|
||||
== test_persona.llm_model_provider_override
|
||||
and persona.llm_model_version_override
|
||||
== test_persona.llm_model_version_override
|
||||
and set(persona.prompts) == set(test_persona.prompt_ids)
|
||||
and set(persona.document_sets) == set(test_persona.document_set_ids)
|
||||
and set(persona.tools) == set(test_persona.tool_ids)
|
||||
and set(user.email for user in persona.users)
|
||||
== set(test_persona.users)
|
||||
and set(persona.groups) == set(test_persona.groups)
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
persona: TestPersona,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> bool:
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
return response.ok
|
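# A brief sketch of the PersonaManager create/verify/delete cycle used by the
# permission tests in this commit; the persona name and description are illustrative.
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.managers.user import UserManager


def _example_persona_lifecycle() -> None:
    admin_user = UserManager.create(name="persona_admin")
    persona = PersonaManager.create(
        name="example-persona",
        description="Persona used only for illustration",
        user_performing_action=admin_user,
    )
    # verify returns True only if the fetched persona matches the local model
    assert PersonaManager.verify(persona, user_performing_action=admin_user)
    assert PersonaManager.delete(persona, user_performing_action=admin_user)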
122
backend/tests/integration/common_utils/managers/user.py
Normal file
122
backend/tests/integration/common_utils/managers/user.py
Normal file
@ -0,0 +1,122 @@
|
||||
from copy import deepcopy
|
||||
from urllib.parse import urlencode
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.manage.models import AllUsersResponse
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
) -> TestUser:
|
||||
if name is None:
|
||||
name = f"test{str(uuid4())}"
|
||||
|
||||
email = f"{name}@test.com"
|
||||
password = "test"
|
||||
|
||||
body = {
|
||||
"email": email,
|
||||
"username": email,
|
||||
"password": password,
|
||||
}
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/auth/register",
|
||||
json=body,
|
||||
headers=GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
test_user = TestUser(
|
||||
id=response.json()["id"],
|
||||
email=email,
|
||||
password=password,
|
||||
headers=deepcopy(GENERAL_HEADERS),
|
||||
)
|
||||
print(f"Created user {test_user.email}")
|
||||
|
||||
test_user.headers["Cookie"] = UserManager.login_as_user(test_user)
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def login_as_user(test_user: TestUser) -> str:
|
||||
data = urlencode(
|
||||
{
|
||||
"username": test_user.email,
|
||||
"password": test_user.password,
|
||||
}
|
||||
)
|
||||
headers = test_user.headers.copy()
|
||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/auth/login",
|
||||
data=data,
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result_cookie = next(iter(response.cookies), None)
|
||||
|
||||
if not result_cookie:
|
||||
raise Exception("Failed to login")
|
||||
|
||||
print(f"Logged in as {test_user.email}")
|
||||
return f"{result_cookie.name}={result_cookie.value}"
|
||||
|
||||
@staticmethod
|
||||
def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me",
|
||||
headers=user_to_verify.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return target_role == UserRole(response.json().get("role", ""))
|
||||
|
||||
@staticmethod
|
||||
def set_role(
|
||||
user_to_set: TestUser,
|
||||
target_role: UserRole,
|
||||
user_to_perform_action: TestUser | None = None,
|
||||
) -> None:
|
||||
if user_to_perform_action is None:
|
||||
user_to_perform_action = user_to_set
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/set-user-role",
|
||||
json={"user_email": user_to_set.email, "new_role": target_role.value},
|
||||
headers=user_to_perform_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None:
|
||||
if user_to_perform_action is None:
|
||||
user_to_perform_action = user
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/users",
|
||||
headers=user_to_perform_action.headers
|
||||
if user_to_perform_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
all_users = AllUsersResponse(
|
||||
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
|
||||
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
|
||||
accepted_pages=data["accepted_pages"],
|
||||
invited_pages=data["invited_pages"],
|
||||
)
|
||||
for accepted_user in all_users.accepted:
|
||||
if accepted_user.email == user.email and accepted_user.id == user.id:
|
||||
return
|
||||
raise ValueError(f"User {user.email} not found")
|
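# A small sketch of UserManager, as used at the start of every test in this commit:
# the first registered user automatically becomes an admin; user names are illustrative.
from danswer.db.models import UserRole
from tests.integration.common_utils.managers.user import UserManager


def _example_user_setup() -> None:
    admin_user = UserManager.create(name="admin_user")
    assert UserManager.verify_role(admin_user, UserRole.ADMIN)

    second_user = UserManager.create(name="second_user")
    # promote the second user using the admin's session cookie
    UserManager.set_role(
        user_to_set=second_user,
        target_role=UserRole.ADMIN,
        user_to_perform_action=admin_user,
    )
    assert UserManager.verify_role(second_user, UserRole.ADMIN)
    # confirm both show up in GET /manage/users
    UserManager.verify(admin_user)
    UserManager.verify(second_user, user_to_perform_action=admin_user)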
148
backend/tests/integration/common_utils/managers/user_group.py
Normal file
148
backend/tests/integration/common_utils/managers/user_group.py
Normal file
@ -0,0 +1,148 @@
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from ee.danswer.server.user_group.models import UserGroup
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import TestUserGroup
|
||||
|
||||
|
||||
class UserGroupManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
name: str | None = None,
|
||||
user_ids: list[str] | None = None,
|
||||
cc_pair_ids: list[int] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestUserGroup:
|
||||
name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}"
|
||||
|
||||
request = {
|
||||
"name": name,
|
||||
"user_ids": user_ids or [],
|
||||
"cc_pair_ids": cc_pair_ids or [],
|
||||
}
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
json=request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
test_user_group = TestUserGroup(
|
||||
id=response.json()["id"],
|
||||
name=response.json()["name"],
|
||||
user_ids=[user["id"] for user in response.json()["users"]],
|
||||
cc_pair_ids=[cc_pair["id"] for cc_pair in response.json()["cc_pairs"]],
|
||||
)
|
||||
return test_user_group
|
||||
|
||||
@staticmethod
|
||||
def edit(
|
||||
user_group: TestUserGroup,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
if not user_group.id:
|
||||
raise ValueError("User group has no ID")
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
|
||||
json=user_group.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def set_curator_status(
|
||||
test_user_group: TestUserGroup,
|
||||
user_to_set_as_curator: TestUser,
|
||||
is_curator: bool = True,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
if not user_to_set_as_curator.id:
|
||||
raise ValueError("User has no ID")
|
||||
set_curator_request = {
|
||||
"user_id": user_to_set_as_curator.id,
|
||||
"is_curator": is_curator,
|
||||
}
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator",
|
||||
json=set_curator_request,
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[UserGroup]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
user_group: TestUserGroup,
|
||||
verify_deleted: bool = False,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
all_user_groups = UserGroupManager.get_all(user_performing_action)
|
||||
for fetched_user_group in all_user_groups:
|
||||
if user_group.id == fetched_user_group.id:
|
||||
if verify_deleted:
|
||||
raise ValueError(
|
||||
f"User group {user_group.id} found but should be deleted"
|
||||
)
|
||||
fetched_cc_ids = {cc_pair.id for cc_pair in fetched_user_group.cc_pairs}
|
||||
fetched_user_ids = {user.id for user in fetched_user_group.users}
|
||||
user_group_cc_ids = set(user_group.cc_pair_ids)
|
||||
user_group_user_ids = set(user_group.user_ids)
|
||||
if (
|
||||
fetched_cc_ids == user_group_cc_ids
|
||||
and fetched_user_ids == user_group_user_ids
|
||||
):
|
||||
return
|
||||
if not verify_deleted:
|
||||
raise ValueError(f"User group {user_group.id} not found")
|
||||
|
||||
@staticmethod
|
||||
def wait_for_sync(
|
||||
user_groups_to_check: list[TestUserGroup] | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> None:
|
||||
start = time.time()
|
||||
while True:
|
||||
user_groups = UserGroupManager.get_all(user_performing_action)
|
||||
if user_groups_to_check:
|
||||
check_ids = {user_group.id for user_group in user_groups_to_check}
|
||||
user_group_ids = {user_group.id for user_group in user_groups}
|
||||
if not check_ids.issubset(user_group_ids):
|
||||
raise RuntimeError("Document set not found")
|
||||
user_groups = [
|
||||
user_group
|
||||
for user_group in user_groups
|
||||
if user_group.id in check_ids
|
||||
]
|
||||
if all(ug.is_up_to_date for ug in user_groups):
|
||||
return
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError(
|
||||
f"User groups were not synced within the {MAX_DELAY} seconds"
|
||||
)
|
||||
else:
|
||||
print("User groups were not synced yet, waiting...")
|
||||
time.sleep(2)
|
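# A compact sketch of how the curator tests in this commit exercise UserGroupManager;
# the curator user and cc_pair below are illustrative placeholders.
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def _example_user_group_with_curator() -> None:
    admin_user = UserManager.create(name="admin_user")
    curator = UserManager.create(name="curator_candidate")
    cc_pair = CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
    user_group = UserGroupManager.create(
        user_ids=[curator.id],
        cc_pair_ids=[cc_pair.id],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group],
        user_performing_action=admin_user,
    )
    # mark the group member as a curator of this group
    UserGroupManager.set_curator_status(
        test_user_group=user_group,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )
    UserGroupManager.verify(user_group, user_performing_action=admin_user)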
@ -20,7 +20,6 @@ from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.main import setup_postgres
|
||||
from danswer.main import setup_vespa
|
||||
from tests.integration.common_utils.llm import seed_default_openai_provider
|
||||
|
||||
|
||||
def _run_migrations(
|
||||
@ -167,6 +166,4 @@ def reset_all() -> None:
|
||||
reset_postgres()
|
||||
print("Resetting Vespa...")
|
||||
reset_vespa()
|
||||
print("Seeding LLM Providers...")
|
||||
seed_default_openai_provider()
|
||||
print("Finished resetting all.")
|
||||
|
@ -1,72 +0,0 @@
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from tests.integration.common_utils.connectors import ConnectorClient
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
|
||||
|
||||
class SimpleTestDocument(BaseModel):
|
||||
id: str
|
||||
content: str
|
||||
|
||||
|
||||
class SeedDocumentResponse(BaseModel):
|
||||
cc_pair_id: int
|
||||
documents: list[SimpleTestDocument]
|
||||
|
||||
|
||||
class TestDocumentClient:
|
||||
@staticmethod
|
||||
def seed_documents(
|
||||
num_docs: int = 5, cc_pair_id: int | None = None
|
||||
) -> SeedDocumentResponse:
|
||||
if not cc_pair_id:
|
||||
connector_details = ConnectorClient.create_connector()
|
||||
cc_pair_id = connector_details.cc_pair_id
|
||||
|
||||
# Create and ingest some documents
|
||||
documents: list[dict] = []
|
||||
for _ in range(num_docs):
|
||||
document_id = f"test-doc-{uuid.uuid4()}"
|
||||
document = {
|
||||
"document": {
|
||||
"id": document_id,
|
||||
"sections": [
|
||||
{
|
||||
"text": f"This is test document {document_id}",
|
||||
"link": f"{document_id}",
|
||||
}
|
||||
],
|
||||
"source": DocumentSource.NOT_APPLICABLE,
|
||||
# just for testing metadata
|
||||
"metadata": {"document_id": document_id},
|
||||
"semantic_identifier": f"Test Document {document_id}",
|
||||
"from_ingestion_api": True,
|
||||
},
|
||||
"cc_pair_id": cc_pair_id,
|
||||
}
|
||||
documents.append(document)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/danswer-api/ingestion",
|
||||
json=document,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
print("Seeding completed successfully.")
|
||||
return SeedDocumentResponse(
|
||||
cc_pair_id=cc_pair_id,
|
||||
documents=[
|
||||
SimpleTestDocument(
|
||||
id=document["document"]["id"],
|
||||
content=document["document"]["sections"][0]["text"],
|
||||
)
|
||||
for document in documents
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed_documents_resp = TestDocumentClient.seed_documents()
|
120
backend/tests/integration/common_utils/test_models.py
Normal file
120
backend/tests/integration/common_utils/test_models.py
Normal file
@ -0,0 +1,120 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from danswer.server.documents.models import InputType
|
||||
|
||||
"""
|
||||
These data models are used to represent the data on the testing side of things.
|
||||
This means the flow is:
|
||||
1. Make request that changes data in db
|
||||
2. Make a change to the testing model
|
||||
3. Retrieve data from db
|
||||
4. Compare db data with testing model to verify
|
||||
"""
|
||||
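# A hedged sketch of the four-step flow described above, using the managers added in
# this commit (the concrete objects named here are illustrative placeholders):
#
#   # 1. make a request that changes data in the db
#   CCPairManager.delete(cc_pair=cc_pair_1, user_performing_action=admin_user)
#   CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
#
#   # 2. make the matching change to the testing model
#   doc_set_1.cc_pair_ids = []
#
#   # 3 + 4. retrieve the db state and compare it with the testing model
#   DocumentSetManager.verify(document_set=doc_set_1, user_performing_action=admin_user)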
|
||||
|
||||
class TestAPIKey(BaseModel):
|
||||
api_key_id: int
|
||||
api_key_display: str
|
||||
api_key: str | None = None # only present on initial creation
|
||||
api_key_name: str | None = None
|
||||
api_key_role: UserRole
|
||||
|
||||
user_id: UUID
|
||||
headers: dict
|
||||
|
||||
|
||||
class TestUser(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
password: str
|
||||
headers: dict
|
||||
|
||||
|
||||
class TestCredential(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
credential_json: dict[str, Any]
|
||||
admin_public: bool
|
||||
source: DocumentSource
|
||||
curator_public: bool
|
||||
groups: list[int]
|
||||
|
||||
|
||||
class TestConnector(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
source: DocumentSource
|
||||
input_type: InputType
|
||||
connector_specific_config: dict[str, Any]
|
||||
groups: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
|
||||
|
||||
class SimpleTestDocument(BaseModel):
|
||||
id: str
|
||||
content: str
|
||||
|
||||
|
||||
class TestCCPair(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
is_public: bool
|
||||
groups: list[int]
|
||||
documents: list[SimpleTestDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TestUserGroup(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
user_ids: list[str]
|
||||
cc_pair_ids: list[int]
|
||||
|
||||
|
||||
class TestLLMProvider(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
default_model_name: str
|
||||
is_public: bool
|
||||
groups: list[TestUserGroup]
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
|
||||
|
||||
class TestDocumentSet(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
cc_pair_ids: list[int] = Field(default_factory=list)
|
||||
is_public: bool
|
||||
is_up_to_date: bool
|
||||
users: list[str] = Field(default_factory=list)
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
prompt_ids: list[int]
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
users: list[str]
|
||||
groups: list[int]
|
@ -1,24 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
|
||||
from ee.danswer.server.user_group.models import UserGroup
|
||||
from ee.danswer.server.user_group.models import UserGroupCreate
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
|
||||
|
||||
class UserGroupClient:
|
||||
@staticmethod
|
||||
def create_user_group(user_group_creation_request: UserGroupCreate) -> int:
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
json=user_group_creation_request.model_dump(),
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cast(int, response.json()["id"])
|
||||
|
||||
@staticmethod
|
||||
def fetch_user_groups() -> list[UserGroup]:
|
||||
response = requests.get(f"{API_SERVER_URL}/manage/admin/user-group")
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
@ -9,6 +10,25 @@ from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
|
||||
|
||||
def load_env_vars(env_file: str = ".env") -> None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
env_path = os.path.join(current_dir, env_file)
|
||||
try:
|
||||
with open(env_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
key, value = line.split("=", 1)
|
||||
os.environ[key] = value.strip()
|
||||
print("Successfully loaded environment variables")
|
||||
except FileNotFoundError:
|
||||
print(f"File {env_file} not found")
|
||||
|
||||
|
||||
# Load environment variables at the module level
|
||||
load_env_vars()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
with get_session_context_manager() as session:
|
||||
|
@ -1,190 +1,305 @@
|
||||
import time
|
||||
"""
|
||||
This file contains tests for the following:
|
||||
- Ensuring deletion of a connector also:
|
||||
- deletes the documents in vespa for that connector
|
||||
- updates the document sets and user groups to remove the connector
|
||||
- Ensuring that deleting a connector that is part of an overlapping document set and/or user group works as expected
|
||||
"""
|
||||
from uuid import uuid4
|
||||
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from tests.integration.common_utils.connectors import ConnectorClient
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.document_sets import DocumentSetClient
|
||||
from tests.integration.common_utils.seed_documents import TestDocumentClient
|
||||
from tests.integration.common_utils.user_groups import UserGroupClient
|
||||
from tests.integration.common_utils.user_groups import UserGroupCreate
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.document_set import DocumentSetManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
from tests.integration.common_utils.test_models import TestUserGroup
|
||||
from tests.integration.common_utils.vespa import TestVespaClient
|
||||
|
||||
|
||||
def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
|
||||
# create connectors
|
||||
c1_details = ConnectorClient.create_connector(name_prefix="tc1")
|
||||
c2_details = ConnectorClient.create_connector(name_prefix="tc2")
|
||||
c1_seed_res = TestDocumentClient.seed_documents(
|
||||
num_docs=5, cc_pair_id=c1_details.cc_pair_id
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
# add api key to user
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
c2_seed_res = TestDocumentClient.seed_documents(
|
||||
num_docs=5, cc_pair_id=c2_details.cc_pair_id
|
||||
|
||||
# create connectors
|
||||
cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_2 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# seed documents
|
||||
cc_pair_1 = DocumentManager.seed_and_attach_docs(
|
||||
cc_pair=cc_pair_1,
|
||||
num_docs=NUM_DOCS,
|
||||
api_key=api_key,
|
||||
)
|
||||
cc_pair_2 = DocumentManager.seed_and_attach_docs(
|
||||
cc_pair=cc_pair_2,
|
||||
num_docs=NUM_DOCS,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# create document sets
|
||||
doc_set_1_id = DocumentSetClient.create_document_set(
|
||||
DocumentSetCreationRequest(
|
||||
name="Test Document Set 1",
|
||||
description="Intially connector to be deleted, should be empty after test",
|
||||
cc_pair_ids=[c1_details.cc_pair_id],
|
||||
is_public=True,
|
||||
users=[],
|
||||
groups=[],
|
||||
)
|
||||
doc_set_1 = DocumentSetManager.create(
|
||||
name="Test Document Set 1",
|
||||
cc_pair_ids=[cc_pair_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
doc_set_2_id = DocumentSetClient.create_document_set(
|
||||
DocumentSetCreationRequest(
|
||||
name="Test Document Set 2",
|
||||
description="Intially both connectors, should contain undeleted connector after test",
|
||||
cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id],
|
||||
is_public=True,
|
||||
users=[],
|
||||
groups=[],
|
||||
)
|
||||
doc_set_2 = DocumentSetManager.create(
|
||||
name="Test Document Set 2",
|
||||
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# wait for document sets to be synced
|
||||
start = time.time()
|
||||
while True:
|
||||
doc_sets = DocumentSetClient.fetch_document_sets()
|
||||
doc_set_1 = next(
|
||||
(doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
|
||||
)
|
||||
doc_set_2 = next(
|
||||
(doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
|
||||
)
|
||||
|
||||
if not doc_set_1 or not doc_set_2:
|
||||
raise RuntimeError("Document set not found")
|
||||
|
||||
if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
|
||||
break
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError("Document sets were not synced within the max delay")
|
||||
|
||||
time.sleep(2)
|
||||
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
|
||||
|
||||
print("Document sets created and synced")
|
||||
|
||||
# if so, create ACLs
|
||||
user_group_1 = UserGroupClient.create_user_group(
|
||||
UserGroupCreate(
|
||||
name="Test User Group 1", user_ids=[], cc_pair_ids=[c1_details.cc_pair_id]
|
||||
)
|
||||
# create user groups
|
||||
user_group_1: TestUserGroup = UserGroupManager.create(
|
||||
cc_pair_ids=[cc_pair_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
user_group_2 = UserGroupClient.create_user_group(
|
||||
UserGroupCreate(
|
||||
name="Test User Group 2",
|
||||
user_ids=[],
|
||||
cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id],
|
||||
)
|
||||
user_group_2: TestUserGroup = UserGroupManager.create(
|
||||
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# wait for user groups to be available
|
||||
start = time.time()
|
||||
while True:
|
||||
user_groups = {ug.id: ug for ug in UserGroupClient.fetch_user_groups()}
|
||||
|
||||
if not (
|
||||
user_group_1 in user_groups.keys() and user_group_2 in user_groups.keys()
|
||||
):
|
||||
raise RuntimeError("User groups not found")
|
||||
|
||||
if (
|
||||
user_groups[user_group_1].is_up_to_date
|
||||
and user_groups[user_group_2].is_up_to_date
|
||||
):
|
||||
break
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError("User groups were not synced within the max delay")
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
print("User groups created and synced")
|
||||
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
|
||||
|
||||
# delete connector 1
|
||||
ConnectorClient.update_connector_status(
|
||||
cc_pair_id=c1_details.cc_pair_id, status=ConnectorCredentialPairStatus.PAUSED
|
||||
CCPairManager.pause_cc_pair(
|
||||
cc_pair=cc_pair_1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
ConnectorClient.delete_connector(
|
||||
connector_id=c1_details.connector_id, credential_id=c1_details.credential_id
|
||||
CCPairManager.delete(
|
||||
cc_pair=cc_pair_1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
start = time.time()
|
||||
while True:
|
||||
connectors = ConnectorClient.get_connectors()
|
||||
# Update local records to match the database for later comparison
|
||||
user_group_1.cc_pair_ids = []
|
||||
user_group_2.cc_pair_ids = [cc_pair_2.id]
|
||||
doc_set_1.cc_pair_ids = []
|
||||
doc_set_2.cc_pair_ids = [cc_pair_2.id]
|
||||
cc_pair_1.groups = []
|
||||
cc_pair_2.groups = [user_group_2.id]
|
||||
|
||||
if c1_details.connector_id not in [c["id"] for c in connectors]:
|
||||
break
|
||||
|
||||
if time.time() - start > MAX_DELAY:
|
||||
raise TimeoutError("Connector 1 was not deleted within the max delay")
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
print("Connector 1 deleted")
|
||||
CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
|
||||
|
||||
# validate vespa documents
|
||||
c1_vespa_docs = vespa_client.get_documents_by_id(
|
||||
[doc.id for doc in c1_seed_res.documents]
|
||||
)["documents"]
|
||||
c2_vespa_docs = vespa_client.get_documents_by_id(
|
||||
[doc.id for doc in c2_seed_res.documents]
|
||||
)["documents"]
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_1,
|
||||
doc_set_names=[],
|
||||
group_names=[],
|
||||
doc_creating_user=admin_user,
|
||||
verify_deleted=True,
|
||||
)
|
||||
|
||||
assert len(c1_vespa_docs) == 0
|
||||
assert len(c2_vespa_docs) == 5
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_2,
|
||||
doc_set_names=[doc_set_2.name],
|
||||
group_names=[user_group_2.name],
|
||||
doc_creating_user=admin_user,
|
||||
verify_deleted=False,
|
||||
)
|
||||
|
||||
for doc in c2_vespa_docs:
|
||||
assert doc["fields"]["access_control_list"] == {
|
||||
"PUBLIC": 1,
|
||||
"group:Test User Group 2": 1,
|
||||
}
|
||||
assert doc["fields"]["document_sets"] == {"Test Document Set 2": 1}
|
||||
# check that only connector 1 is deleted
|
||||
CCPairManager.verify(
|
||||
cc_pair=cc_pair_2,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate document sets
|
||||
DocumentSetManager.verify(
|
||||
document_set=doc_set_1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
DocumentSetManager.verify(
|
||||
document_set=doc_set_2,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# validate user groups
|
||||
UserGroupManager.verify(
|
||||
user_group=user_group_1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.verify(
|
||||
user_group=user_group_2,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
def test_connector_deletion_for_overlapping_connectors(
|
||||
reset: None, vespa_client: TestVespaClient
|
||||
) -> None:
|
||||
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
|
||||
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
|
||||
"""
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
# add api key to user
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# create connectors
|
||||
cc_pair_1 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_2 = CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
doc_ids = [str(uuid4())]
|
||||
cc_pair_1 = DocumentManager.seed_and_attach_docs(
|
||||
cc_pair=cc_pair_1,
|
||||
document_ids=doc_ids,
|
||||
api_key=api_key,
|
||||
)
|
||||
cc_pair_2 = DocumentManager.seed_and_attach_docs(
|
||||
cc_pair=cc_pair_2,
|
||||
document_ids=doc_ids,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
# verify vespa document exists and that it is not in any document sets or groups
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_1,
|
||||
doc_set_names=[],
|
||||
group_names=[],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_2,
|
||||
doc_set_names=[],
|
||||
group_names=[],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
# create document set
|
||||
doc_set_1 = DocumentSetManager.create(
|
||||
name="Test Document Set 1",
|
||||
cc_pair_ids=[cc_pair_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
DocumentSetManager.wait_for_sync(
|
||||
document_sets_to_check=[doc_set_1],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
print("Document set 1 created and synced")
|
||||
|
||||
# verify vespa document is in the document set
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_1,
|
||||
doc_set_names=[doc_set_1.name],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_2,
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
# create a user group and attach it to connector 1
|
||||
user_group_1: TestUserGroup = UserGroupManager.create(
|
||||
name="Test User Group 1",
|
||||
cc_pair_ids=[cc_pair_1.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[user_group_1],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_1.groups = [user_group_1.id]
|
||||
|
||||
print("User group 1 created and synced")
|
||||
|
||||
# create a user group and attach it to connector 2
|
||||
user_group_2: TestUserGroup = UserGroupManager.create(
|
||||
name="Test User Group 2",
|
||||
cc_pair_ids=[cc_pair_2.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[user_group_2],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair_2.groups = [user_group_2.id]
|
||||
|
||||
print("User group 2 created and synced")
|
||||
|
||||
# verify vespa document is in the user group
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_1,
|
||||
group_names=[user_group_1.name, user_group_2.name],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
DocumentManager.verify(
|
||||
vespa_client=vespa_client,
|
||||
cc_pair=cc_pair_2,
|
||||
group_names=[user_group_1.name, user_group_2.name],
|
||||
doc_creating_user=admin_user,
|
||||
)
|
||||
|
||||
# delete connector 1
|
||||
CCPairManager.pause_cc_pair(
|
||||
cc_pair=cc_pair_1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.delete(
|
||||
cc_pair=cc_pair_1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# EVERYTHING BELOW HERE IS CURRENTLY BROKEN AND NEEDS TO BE FIXED SERVER SIDE
|
||||
|
||||
# wait for deletion to finish
|
||||
# CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
|
||||
|
||||
# print("Connector 1 deleted")
|
||||
|
||||
# check that only connector 1 is deleted
|
||||
# TODO: check for the CC pair rather than the connector once the refactor is done
|
||||
all_connectors = ConnectorClient.get_connectors()
|
||||
assert len(all_connectors) == 1
|
||||
assert all_connectors[0]["id"] == c2_details.connector_id
|
||||
# CCPairManager.verify(
|
||||
# cc_pair=cc_pair_1,
|
||||
# verify_deleted=True,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
# CCPairManager.verify(
|
||||
# cc_pair=cc_pair_2,
|
||||
# user_performing_action=admin_user,
|
||||
# )
|
||||
|
||||
# validate document sets
|
||||
all_doc_sets = DocumentSetClient.fetch_document_sets()
|
||||
assert len(all_doc_sets) == 2
|
||||
|
||||
doc_set_1_found = False
|
||||
doc_set_2_found = False
|
||||
for doc_set in all_doc_sets:
|
||||
if doc_set.id == doc_set_1_id:
|
||||
doc_set_1_found = True
|
||||
assert doc_set.cc_pair_descriptors == []
|
||||
|
||||
if doc_set.id == doc_set_2_id:
|
||||
doc_set_2_found = True
|
||||
assert len(doc_set.cc_pair_descriptors) == 1
|
||||
assert doc_set.cc_pair_descriptors[0].id == c2_details.cc_pair_id
|
||||
|
||||
assert doc_set_1_found
|
||||
assert doc_set_2_found
|
||||
|
||||
# validate user groups
|
||||
all_user_groups = UserGroupClient.fetch_user_groups()
|
||||
assert len(all_user_groups) == 2
|
||||
|
||||
user_group_1_found = False
|
||||
user_group_2_found = False
|
||||
for user_group in all_user_groups:
|
||||
if user_group.id == user_group_1:
|
||||
user_group_1_found = True
|
||||
assert user_group.cc_pairs == []
|
||||
if user_group.id == user_group_2:
|
||||
user_group_2_found = True
|
||||
assert len(user_group.cc_pairs) == 1
|
||||
assert user_group.cc_pairs[0].id == c2_details.cc_pair_id
|
||||
|
||||
assert user_group_1_found
|
||||
assert user_group_2_found
|
||||
# verify the document is not in any document sets
|
||||
# verify the document is only in user group 2
|
||||
# DocumentManager.verify(
|
||||
# vespa_client=vespa_client,
|
||||
# cc_pair=cc_pair_2,
|
||||
# doc_set_names=[],
|
||||
# group_names=[user_group_2.name],
|
||||
# doc_creating_user=admin_user,
|
||||
# verify_deleted=False,
|
||||
# )
|
||||
|
@ -1,34 +1,59 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.connectors import ConnectorClient
|
||||
from danswer.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.seed_documents import TestDocumentClient
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.llm import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestAPIKey
|
||||
from tests.integration.common_utils.test_models import TestCCPair
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
def test_send_message_simple_with_history(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connectors
|
||||
c1_details = ConnectorClient.create_connector(name_prefix="tc1")
|
||||
c1_seed_res = TestDocumentClient.seed_documents(
|
||||
num_docs=5, cc_pair_id=c1_details.cc_pair_id
|
||||
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key: TestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
cc_pair_1 = DocumentManager.seed_and_attach_docs(
|
||||
cc_pair=cc_pair_1,
|
||||
num_docs=NUM_DOCS,
|
||||
api_key=api_key,
|
||||
)
    response = requests.post(
        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
        json={
            "messages": [{"message": c1_seed_res.documents[0].content, "role": "user"}],
            "messages": [
                {
                    "message": cc_pair_1.documents[0].content,
                    "role": MessageType.USER.value,
                }
            ],
            "persona_id": 0,
            "prompt_id": 0,
        },
        headers=admin_user.headers,
    )
    assert response.status_code == 200

    response_json = response.json()

    # Check that the top document is the correct document
    assert response_json["simple_search_docs"][0]["id"] == c1_seed_res.documents[0].id
    assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id

    # assert that the metadata is correct
    for doc in c1_seed_res.documents:
    for doc in cc_pair_1.documents:
        found_doc = next(
            (x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
        )

@ -1,78 +1,66 @@
import time

from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common_utils.document_sets import DocumentSetClient
from tests.integration.common_utils.seed_documents import TestDocumentClient
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.vespa import TestVespaClient


def test_multiple_document_sets_syncing_same_connnector(
    reset: None, vespa_client: TestVespaClient
) -> None:
    # Seed documents
    seed_result = TestDocumentClient.seed_documents(num_docs=5)
    cc_pair_id = seed_result.cc_pair_id
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Create first document set
    doc_set_1_id = DocumentSetClient.create_document_set(
        DocumentSetCreationRequest(
            name="Test Document Set 1",
            description="First test document set",
            cc_pair_ids=[cc_pair_id],
            is_public=True,
            users=[],
            groups=[],
        )
    # add api key to user
    api_key: TestAPIKey = APIKeyManager.create(
        user_performing_action=admin_user,
    )

    doc_set_2_id = DocumentSetClient.create_document_set(
        DocumentSetCreationRequest(
            name="Test Document Set 2",
            description="Second test document set",
            cc_pair_ids=[cc_pair_id],
            is_public=True,
            users=[],
            groups=[],
        )
    # create connector
    cc_pair_1 = CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    # wait for syncing to be complete
    max_delay = 45
    start = time.time()
    while True:
        doc_sets = DocumentSetClient.fetch_document_sets()
        doc_set_1 = next(
            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
        )
        doc_set_2 = next(
            (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
        )
    # seed documents
    cc_pair_1 = DocumentManager.seed_and_attach_docs(
        cc_pair=cc_pair_1,
        num_docs=NUM_DOCS,
        api_key=api_key,
    )

        if not doc_set_1 or not doc_set_2:
            raise RuntimeError("Document set not found")
    # Create document sets
    doc_set_1 = DocumentSetManager.create(
        cc_pair_ids=[cc_pair_1.id],
        user_performing_action=admin_user,
    )
    doc_set_2 = DocumentSetManager.create(
        cc_pair_ids=[cc_pair_1.id],
        user_performing_action=admin_user,
    )
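    # Both document sets reference the same cc_pair; syncing multiple document
    # sets against a single connector is the case under test.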

        if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
            assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [
                ccp.id for ccp in doc_set_2.cc_pair_descriptors
            ]
            break
    DocumentSetManager.wait_for_sync(
        user_performing_action=admin_user,
    )

        if time.time() - start > max_delay:
            raise TimeoutError("Document sets were not synced within the max delay")

        time.sleep(2)

    # get names so we can compare to what is in vespa
    doc_sets = DocumentSetClient.fetch_document_sets()
    doc_set_names = {doc_set.name for doc_set in doc_sets}
    DocumentSetManager.verify(
        document_set=doc_set_1,
        user_performing_action=admin_user,
    )
    DocumentSetManager.verify(
        document_set=doc_set_2,
        user_performing_action=admin_user,
    )

    # make sure documents are as expected
    seeded_document_ids = [doc.id for doc in seed_result.documents]

    result = vespa_client.get_documents_by_id([doc.id for doc in seed_result.documents])
    documents = result["documents"]
    assert len(documents) == len(seed_result.documents)
    assert all(doc["fields"]["document_id"] in seeded_document_ids for doc in documents)
    assert all(
        set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents
    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_1,
        doc_set_names=[doc_set_1.name, doc_set_2.name],
        doc_creating_user=admin_user,
    )

@ -0,0 +1,179 @@
"""
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connector-credential pairs.
"""
import pytest
from requests.exceptions import HTTPError

from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def test_cc_pair_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="curated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
    # setting the user as a curator for the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating another user group that the user is not a curator of
    user_group_2 = UserGroupManager.create(
        name="uncurated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Create credentials that the curator is and is not a curator of
    connector_1 = ConnectorManager.create(
        name="curator_owned_connector",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=admin_user,
    )
    # currently we don't enforce permissions at the connector level
    # pending cc_pair -> connector rework
    # connector_2 = ConnectorManager.create(
    #     name="curator_visible_connector",
    #     source=DocumentSource.CONFLUENCE,
    #     groups=[user_group_2.id],
    #     is_public=False,
    #     user_performing_action=admin_user,
    # )
    credential_1 = CredentialManager.create(
        name="curator_owned_credential",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        curator_public=False,
        user_performing_action=admin_user,
    )
    credential_2 = CredentialManager.create(
        name="curator_visible_credential",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_2.id],
        curator_public=False,
        user_performing_action=admin_user,
    )
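    # connector_1 and credential_1 belong to the curated group (user_group_1);
    # credential_2 belongs to the group the curator does not curate (user_group_2).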

    # END OF HAPPY PATH

    """Tests for things Curators should not be able to do"""

    # Curators should not be able to create a public cc pair
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_1",
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

    # Curators should not be able to create a cc
    # pair for a user group they are not a curator of
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_2",
            groups=[user_group_1.id, user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

    # Curators should not be able to create a cc
    # pair without an attached user group
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_2",
            groups=[],
            is_public=False,
            user_performing_action=curator,
        )

    # # This test is currently disabled because permissions are
    # # not enforced at the connector level
    # # Curators should not be able to create a cc pair
    # # for a user group that the connector does not belong to (NOT WORKING)
    # with pytest.raises(HTTPError):
    #     CCPairManager.create(
    #         connector_id=connector_2.id,
    #         credential_id=credential_1.id,
    #         name="invalid_cc_pair_3",
    #         groups=[user_group_1.id],
    #         is_public=False,
    #         user_performing_action=curator,
    #     )

    # Curators should not be able to create a cc
    # pair for a user group that the credential does not belong to
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_2.id,
            name="invalid_cc_pair_4",
            groups=[user_group_1.id],
            is_public=False,
            user_performing_action=curator,
        )

    """Tests for things Curators should be able to do"""

    # Curators should be able to create a private
    # cc pair for a user group they are a curator of
    valid_cc_pair = CCPairManager.create(
        name="valid_cc_pair",
        connector_id=connector_1.id,
        credential_id=credential_1.id,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )

    # Verify the created cc pair
    CCPairManager.verify(
        cc_pair=valid_cc_pair,
        user_performing_action=curator,
    )

    # Test pausing the cc pair
    CCPairManager.pause_cc_pair(valid_cc_pair, user_performing_action=curator)

    # Test deleting the cc pair
    CCPairManager.delete(valid_cc_pair, user_performing_action=curator)
    CCPairManager.wait_for_deletion_completion(user_performing_action=curator)

    CCPairManager.verify(
        cc_pair=valid_cc_pair,
        verify_deleted=True,
        user_performing_action=curator,
    )

@ -0,0 +1,136 @@
"""
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connectors.
"""
import pytest
from requests.exceptions import HTTPError

from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def test_connector_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="user_group_1",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
    # setting the user as a curator for the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating another user group that the user is not a curator of
    user_group_2 = UserGroupManager.create(
        name="user_group_2",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
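    # user_group_2 exists only to check that the curator cannot act on a group
    # they do not curate.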

    # END OF HAPPY PATH

    """Tests for things Curators should not be able to do"""

    # Curators should not be able to create a public connector
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_1",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

    # Curators should not be able to create a connector for a
    # user group they are not a curator of
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_2",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id, user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

    """Tests for things Curators should be able to do"""

    # Curators should be able to create a private
    # connector for a user group they are a curator of
    valid_connector = ConnectorManager.create(
        name="valid_connector",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )
    assert valid_connector.id is not None

    # Verify the created connector
    created_connector = ConnectorManager.get(
        valid_connector.id, user_performing_action=curator
    )
    assert created_connector.name == valid_connector.name
    assert created_connector.source == valid_connector.source

    # Verify that the connector can be found in the list of all connectors
    all_connectors = ConnectorManager.get_all(user_performing_action=curator)
    assert any(conn.id == valid_connector.id for conn in all_connectors)

    # Test editing the connector
    valid_connector.name = "updated_valid_connector"
    ConnectorManager.edit(valid_connector, user_performing_action=curator)

    # Verify the edit
    updated_connector = ConnectorManager.get(
        valid_connector.id, user_performing_action=curator
    )
    assert updated_connector.name == "updated_valid_connector"

    # Test deleting the connector
    ConnectorManager.delete(connector=valid_connector, user_performing_action=curator)

    # Verify the deletion
    all_connectors_after_delete = ConnectorManager.get_all(
        user_performing_action=curator
    )
    assert all(conn.id != valid_connector.id for conn in all_connectors_after_delete)

    # Test that curator cannot create a connector for a group they are not a curator of
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_3",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

    # Test that curator cannot create a public connector
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_4",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

@ -0,0 +1,108 @@
"""
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating credentials.
"""
import pytest
from requests.exceptions import HTTPError

from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def test_credential_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="user_group_1",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
    # setting the user as a curator for the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating another user group that the user is not a curator of
    user_group_2 = UserGroupManager.create(
        name="user_group_2",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # END OF HAPPY PATH

    """Tests for things Curators should not be able to do"""

    # Curators should not be able to create a public credential
    with pytest.raises(HTTPError):
        CredentialManager.create(
            name="invalid_credential_1",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id],
            curator_public=True,
            user_performing_action=curator,
        )

    # Curators should not be able to create a credential for a user group they are not a curator of
    with pytest.raises(HTTPError):
        CredentialManager.create(
            name="invalid_credential_2",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id, user_group_2.id],
            curator_public=False,
            user_performing_action=curator,
        )

    """Tests for things Curators should be able to do"""
    # Curators should be able to create a private credential for a user group they are a curator of
    valid_credential = CredentialManager.create(
        name="valid_credential",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        curator_public=False,
        user_performing_action=curator,
    )

    # Verify the created credential
    CredentialManager.verify(
        credential=valid_credential,
        user_performing_action=curator,
    )

    # Test editing the credential
    valid_credential.name = "updated_valid_credential"
    CredentialManager.edit(valid_credential, user_performing_action=curator)

    # Verify the edit
    CredentialManager.verify(
        credential=valid_credential,
        user_performing_action=curator,
    )

    # Test deleting the credential
    CredentialManager.delete(valid_credential, user_performing_action=curator)

    # Verify the deletion
    CredentialManager.verify(
        credential=valid_credential,
        verify_deleted=True,
        user_performing_action=curator,
    )

@ -0,0 +1,186 @@
import pytest
from requests.exceptions import HTTPError

from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def test_doc_set_permissions_setup(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a second user (curator)
    curator: TestUser = UserManager.create(name="curator")

    # Creating the first user group
    user_group_1 = UserGroupManager.create(
        name="curated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Setting the curator as a curator for the first user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating a second user group
    user_group_2 = UserGroupManager.create(
        name="uncurated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Admin creates a private cc_pair
    private_cc_pair = CCPairManager.create_from_scratch(
        is_public=False,
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    # Admin creates a public cc_pair
    public_cc_pair = CCPairManager.create_from_scratch(
        is_public=True,
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
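    # Note: private_cc_pair is not attached to any user group yet; the checks
    # below rely on that.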

    # END OF HAPPY PATH

    """Tests for things Curators/Admins should not be able to do"""

    # Test that curator cannot create a document set for the group they don't curate
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 1",
            groups=[user_group_2.id],
            cc_pair_ids=[public_cc_pair.id],
            user_performing_action=curator,
        )

    # Test that curator cannot create a document set attached to both groups
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 2",
            is_public=False,
            cc_pair_ids=[public_cc_pair.id],
            groups=[user_group_1.id, user_group_2.id],
            user_performing_action=curator,
        )

    # Test that curator cannot create a document set with no groups
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 3",
            is_public=False,
            cc_pair_ids=[public_cc_pair.id],
            groups=[],
            user_performing_action=curator,
        )

    # Test that curator cannot create a document set with no cc_pairs
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 4",
            is_public=False,
            cc_pair_ids=[],
            groups=[user_group_1.id],
            user_performing_action=curator,
        )

    # Test that admin cannot create a document set with no cc_pairs
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 4",
            is_public=False,
            cc_pair_ids=[],
            groups=[user_group_1.id],
            user_performing_action=admin_user,
        )

    """Tests for things Curators should be able to do"""
    # Test that curator can create a document set for the group they curate
    valid_doc_set = DocumentSetManager.create(
        name="Valid Document Set",
        is_public=False,
        cc_pair_ids=[public_cc_pair.id],
        groups=[user_group_1.id],
        user_performing_action=curator,
    )

    # Verify that the valid document set was created
    DocumentSetManager.verify(
        document_set=valid_doc_set,
        user_performing_action=admin_user,
    )

    # Verify that only one document set exists
    all_doc_sets = DocumentSetManager.get_all(user_performing_action=admin_user)
    assert len(all_doc_sets) == 1

    # Add the private_cc_pair to the doc set on our end for later comparison
    valid_doc_set.cc_pair_ids.append(private_cc_pair.id)

    # Confirm the curator can't add the private_cc_pair to the doc set
    with pytest.raises(HTTPError):
        DocumentSetManager.edit(
            document_set=valid_doc_set,
            user_performing_action=curator,
        )
    # Confirm the admin can't add the private_cc_pair to the doc set
    with pytest.raises(HTTPError):
        DocumentSetManager.edit(
            document_set=valid_doc_set,
            user_performing_action=admin_user,
        )

    # Verify the document set has not been updated in the db
    with pytest.raises(ValueError):
        DocumentSetManager.verify(
            document_set=valid_doc_set,
            user_performing_action=admin_user,
        )
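    # (DocumentSetManager.verify presumably compares the server state against the
    # passed-in copy, which still contains the never-persisted private_cc_pair,
    # hence the expected ValueError.)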

    # Add the private_cc_pair to the user group on our end for later comparison
    user_group_1.cc_pair_ids.append(private_cc_pair.id)

    # Admin adds the cc_pair to the group the curator curates
    UserGroupManager.edit(
        user_group=user_group_1,
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
    UserGroupManager.verify(
        user_group=user_group_1,
        user_performing_action=admin_user,
    )

    # Confirm the curator can now add the cc_pair to the doc set
    DocumentSetManager.edit(
        document_set=valid_doc_set,
        user_performing_action=curator,
    )
    DocumentSetManager.wait_for_sync(
        document_sets_to_check=[valid_doc_set], user_performing_action=admin_user
    )
    # Verify the updated document set
    DocumentSetManager.verify(
        document_set=valid_doc_set,
        user_performing_action=admin_user,
    )

@ -0,0 +1,93 @@
"""
This file tests the ability of different user types to set the role of other users.
"""
import pytest
from requests.exceptions import HTTPError

from danswer.db.models import UserRole
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def test_user_role_setting_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")
    assert UserManager.verify_role(admin_user, UserRole.ADMIN)

    # Creating a basic user
    basic_user: TestUser = UserManager.create(name="basic_user")
    assert UserManager.verify_role(basic_user, UserRole.BASIC)

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")
    assert UserManager.verify_role(curator, UserRole.BASIC)

    # Creating a curator without adding to a group should not work
    with pytest.raises(HTTPError):
        UserManager.set_role(
            user_to_set=curator,
            target_role=UserRole.CURATOR,
            user_to_perform_action=admin_user,
        )

    global_curator: TestUser = UserManager.create(name="global_curator")
    assert UserManager.verify_role(global_curator, UserRole.BASIC)

    # Setting the role of a global curator should not work for a basic user
    with pytest.raises(HTTPError):
        UserManager.set_role(
            user_to_set=global_curator,
            target_role=UserRole.GLOBAL_CURATOR,
            user_to_perform_action=basic_user,
        )

    # Setting the role of a global curator should work for an admin user
    UserManager.set_role(
        user_to_set=global_curator,
        target_role=UserRole.GLOBAL_CURATOR,
        user_to_perform_action=admin_user,
    )
    assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR)

    # A global curator should not be able to change their own role
    with pytest.raises(HTTPError):
        UserManager.set_role(
            user_to_set=global_curator,
            target_role=UserRole.BASIC,
            user_to_perform_action=global_curator,
        )
    assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR)

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="user_group_1",
        user_ids=[],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # This should fail because the curator is not in the user group
    with pytest.raises(HTTPError):
        UserGroupManager.set_curator_status(
            test_user_group=user_group_1,
            user_to_set_as_curator=curator,
            user_performing_action=admin_user,
        )

    # Adding the curator to the user group
    user_group_1.user_ids = [curator.id]
    UserGroupManager.edit(user_group=user_group_1, user_performing_action=admin_user)
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # This should work because the curator is in the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

@ -0,0 +1,86 @@
"""
This file tests the happy path for curator permissions
"""
from danswer.db.models import UserRole
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


def test_whole_curator_flow(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")
    assert UserManager.verify_role(admin_user, UserRole.ADMIN)

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="user_group_1",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
    # Making curator a curator of user_group_1
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )
    assert UserManager.verify_role(curator, UserRole.CURATOR)
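    # From here on, every action is performed by the curator against user_group_1.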

    # Creating a credential as curator
    test_credential = CredentialManager.create(
        name="curator_test_credential",
        source=DocumentSource.FILE,
        curator_public=False,
        groups=[user_group_1.id],
        user_performing_action=curator,
    )

    # Creating a connector as curator
    test_connector = ConnectorManager.create(
        name="curator_test_connector",
        source=DocumentSource.FILE,
        is_public=False,
        groups=[user_group_1.id],
        user_performing_action=curator,
    )

    # Test editing the connector
    test_connector.name = "updated_test_connector"
    ConnectorManager.edit(connector=test_connector, user_performing_action=curator)

    # Creating a CC pair as curator
    test_cc_pair = CCPairManager.create(
        connector_id=test_connector.id,
        credential_id=test_credential.id,
        name="curator_test_cc_pair",
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )

    CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=admin_user)

    # Verify that the curator can pause and unpause the CC pair
    CCPairManager.pause_cc_pair(cc_pair=test_cc_pair, user_performing_action=curator)

    # Verify that the curator can delete the CC pair
    CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=curator)
    CCPairManager.wait_for_deletion_completion(user_performing_action=curator)

    # Verify that the CC pair has been deleted
    CCPairManager.verify(
        cc_pair=test_cc_pair,
        verify_deleted=True,
        user_performing_action=admin_user,
    )

@ -73,7 +73,7 @@ export function updateCredential(credentialId: number, newDetails: any) {
    ([key, value]) => key !== "name" && value !== ""
  )
  );
  return fetch(`/api/manage/admin/credentials/${credentialId}`, {
  return fetch(`/api/manage/admin/credential/${credentialId}`, {
    method: "PUT",
    headers: {
      "Content-Type": "application/json",

@ -86,7 +86,7 @@ export function updateCredential(credentialId: number, newDetails: any) {
}

export function swapCredential(newCredentialId: number, connectorId: number) {
  return fetch(`/api/manage/admin/credentials/swap`, {
  return fetch(`/api/manage/admin/credential/swap`, {
    method: "PUT",
    headers: {
      "Content-Type": "application/json",