Integration tests (#2256)

* initial commit

* almost done

* finished 3 tests

* minor refactor

* built out initial permission tests

* reworked test_deletion

* removed logging

* all original tests have been converted

* renamed user_groups to user_group

* mypy

* added test for doc set permissions

* unified naming for manager methods

* Refactored models and added new deletion test

* minor additions

* better logging+fixed input variables

* commented out failed tests

* Added readme

* readme update

* Added auth to IT

set auth_type to basic and require_email_verification to false

* Update run-it.yml

* used verify and added to readme

* added api key manager
This commit is contained in:
hagen-danswer 2024-09-01 15:21:00 -07:00 committed by GitHub
parent 634de83d72
commit 8d443ada5b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 2890 additions and 612 deletions

View File

@ -92,6 +92,8 @@ jobs:
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
IMAGE_TAG=it \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build
id: start_docker

View File

@ -67,23 +67,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def validate_curator_request(groups: list | None, is_public: bool) -> None:
    """Reject curator requests that are public or that name no groups.

    The public check takes precedence over the group check. A rejected
    request is logged and surfaced as an ``HTTPException`` with status 401;
    an acceptable request returns ``None``.
    """
    if is_public:
        problem = "Curators cannot create public objects"
    elif not groups:
        problem = "Curators must specify 1+ groups"
    else:
        # Non-public and at least one group: nothing to complain about.
        return

    logger.error(problem)
    raise HTTPException(status_code=401, detail=problem)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True

View File

@ -334,9 +334,13 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")
if credential is None:
error_msg = (
f"Credential {credential_id} does not exist or does not belong to user"
)
logger.error(error_msg)
raise HTTPException(
status_code=401,
detail="Credential does not exist or does not belong to user",
detail=error_msg,
)
existing_association = (
@ -350,7 +354,7 @@ def add_credential_to_connector(
if existing_association is not None:
return StatusResponse(
success=False,
message=f"Connector already has Credential {credential_id}",
message=f"Connector {connector_id} already has Credential {credential_id}",
data=connector_id,
)
@ -374,8 +378,8 @@ def add_credential_to_connector(
db_session.commit()
return StatusResponse(
success=False,
message=f"Connector already has Credential {credential_id}",
success=True,
message=f"Creating new association between Connector {connector_id} and Credential {credential_id}",
data=association.id,
)

View File

@ -1,7 +1,6 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
@ -21,12 +20,13 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import get_index_attempts_for_connector
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
@ -84,10 +84,6 @@ def get_cc_pair_full_info(
)
class CCStatusUpdateRequest(BaseModel):
    """Request body carrying the new status for a connector-credential pair."""

    # Target status, expressed as a ConnectorCredentialPairStatus value.
    status: ConnectorCredentialPairStatus
@router.put("/admin/cc-pair/{cc_pair_id}/status")
def update_cc_pair_status(
cc_pair_id: int,
@ -157,11 +153,12 @@ def associate_credential_to_connector(
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
if user and user.role != UserRole.ADMIN and metadata.is_public:
raise HTTPException(
status_code=400,
detail="Public connections cannot be created by non-admin users",
)
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=metadata.groups,
object_is_public=metadata.is_public,
)
try:
response = add_credential_to_connector(

View File

@ -75,7 +75,6 @@ from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorBase
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import ConnectorSnapshot
@ -93,6 +92,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.documents.models import RunConnectorRequest
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
@ -514,35 +514,6 @@ def _validate_connector_allowed(source: DocumentSource) -> None:
)
def _check_connector_permissions(
    connector_data: ConnectorUpdateRequest, user: User | None
) -> ConnectorBase:
    """
    Stop-gap guard (not a real permission check) that keeps curators from
    creating problematic connectors until the long-term solution lands
    (replacing CC pairs/Connectors with Connections).

    A missing user or an admin bypasses the checks. Anyone else may not create
    a public connector and must supply at least one group; violations raise an
    ``HTTPException`` with status 400. On success the request is narrowed to a
    plain ``ConnectorBase``.
    """
    unrestricted = not user or user.role == UserRole.ADMIN
    if not unrestricted:
        if connector_data.is_public:
            raise HTTPException(
                status_code=400,
                detail="Public connectors can only be created by admins",
            )
        if not connector_data.groups:
            raise HTTPException(
                status_code=400,
                detail="Connectors created by curators must have groups",
            )

    # Copy across only the fields that ConnectorBase is built from here.
    base_field_names = (
        "name",
        "source",
        "input_type",
        "connector_specific_config",
        "refresh_freq",
        "prune_freq",
        "indexing_start",
    )
    base_kwargs = {
        field_name: getattr(connector_data, field_name)
        for field_name in base_field_names
    }
    return ConnectorBase(**base_kwargs)
@router.post("/admin/connector")
def create_connector_from_model(
connector_data: ConnectorUpdateRequest,
@ -551,13 +522,19 @@ def create_connector_from_model(
) -> ObjectCreationIdResponse:
try:
_validate_connector_allowed(connector_data.source)
connector_base = _check_connector_permissions(connector_data, user)
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
object_is_public=connector_data.is_public,
)
connector_base = connector_data.to_connector_base()
return create_connector(
db_session=db_session,
connector_data=connector_base,
)
except ValueError as e:
logger.error(f"Error creating connector: {e}")
raise HTTPException(status_code=400, detail=str(e))
@ -608,12 +585,18 @@ def create_connector_with_mock_credential(
def update_connector_from_model(
connector_id: int,
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_admin_user),
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> ConnectorSnapshot | StatusResponse[int]:
try:
_validate_connector_allowed(connector_data.source)
connector_base = _check_connector_permissions(connector_data, user)
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
object_is_public=connector_data.is_public,
)
connector_base = connector_data.to_connector_base()
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@ -643,7 +626,7 @@ def update_connector_from_model(
@router.delete("/admin/connector/{connector_id}", response_model=StatusResponse[int])
def delete_connector_by_id(
connector_id: int,
_: User = Depends(current_admin_user),
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
try:

View File

@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import validate_curator_request
from danswer.db.credentials import alter_credential
from danswer.db.credentials import create_credential
from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
@ -20,7 +19,6 @@ from danswer.db.credentials import update_credential
from danswer.db.engine import get_session
from danswer.db.models import DocumentSource
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.server.documents.models import CredentialSnapshot
@ -28,6 +26,7 @@ from danswer.server.documents.models import CredentialSwapRequest
from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
@ -80,7 +79,7 @@ def get_cc_source_full_info(
]
@router.get("/credentials/{id}")
@router.get("/credential/{id}")
def list_credentials_by_id(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
@ -105,7 +104,7 @@ def delete_credential_by_id_admin(
)
@router.put("/admin/credentials/swap")
@router.put("/admin/credential/swap")
def swap_credentials_for_connector(
credential_swap_req: CredentialSwapRequest,
user: User | None = Depends(current_user),
@ -131,14 +130,12 @@ def create_credential_from_model(
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
if (
user
and user.role != UserRole.ADMIN
and not _ignore_credential_permissions(credential_info.source)
):
validate_curator_request(
groups=credential_info.groups,
is_public=credential_info.curator_public,
if not _ignore_credential_permissions(credential_info.source):
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=credential_info.groups,
object_is_public=credential_info.curator_public,
)
credential = create_credential(credential_info, user, db_session)
@ -179,7 +176,7 @@ def get_credential_by_id(
return CredentialSnapshot.from_credential_db_model(credential)
@router.put("/admin/credentials/{credential_id}")
@router.put("/admin/credential/{credential_id}")
def update_credential_data(
credential_id: int,
credential_update: CredentialDataUpdateRequest,

View File

@ -48,9 +48,12 @@ class ConnectorBase(BaseModel):
class ConnectorUpdateRequest(ConnectorBase):
is_public: bool | None = None
is_public: bool = True
groups: list[int] = Field(default_factory=list)
def to_connector_base(self) -> ConnectorBase:
return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"}))
class ConnectorSnapshot(ConnectorBase):
id: int
@ -103,11 +106,6 @@ class CredentialSnapshot(CredentialBase):
user_id: UUID | None
time_created: datetime
time_updated: datetime
name: str | None
source: DocumentSource
credential_json: dict[str, Any]
admin_public: bool
curator_public: bool
@classmethod
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
@ -261,6 +259,10 @@ class ConnectorCredentialPairMetadata(BaseModel):
groups: list[int] = Field(default_factory=list)
class CCStatusUpdateRequest(BaseModel):
    """Request body carrying the new status for a connector-credential pair."""

    # Target status, expressed as a ConnectorCredentialPairStatus value.
    status: ConnectorCredentialPairStatus
class ConnectorCredentialPairDescriptor(BaseModel):
id: int
name: str | None = None

View File

@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import validate_curator_request
from danswer.db.document_set import check_document_sets_are_public
from danswer.db.document_set import fetch_all_document_sets_for_user
from danswer.db.document_set import insert_document_set
@ -14,12 +13,12 @@ from danswer.db.document_set import mark_document_set_as_to_be_deleted
from danswer.db.document_set import update_document_set
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.server.features.document_set.models import CheckDocSetPublicRequest
from danswer.server.features.document_set.models import CheckDocSetPublicResponse
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
from ee.danswer.db.user_group import validate_user_creation_permissions
router = APIRouter(prefix="/manage")
@ -31,11 +30,12 @@ def create_document_set(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> int:
if user and user.role != UserRole.ADMIN:
validate_curator_request(
groups=document_set_creation_request.groups,
is_public=document_set_creation_request.is_public,
)
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=document_set_creation_request.groups,
object_is_public=document_set_creation_request.is_public,
)
try:
document_set_db_model, _ = insert_document_set(
document_set_creation_request=document_set_creation_request,
@ -53,11 +53,12 @@ def patch_document_set(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if user and user.role != UserRole.ADMIN:
validate_curator_request(
groups=document_set_update_request.groups,
is_public=document_set_update_request.is_public,
)
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=document_set_update_request.groups,
object_is_public=document_set_update_request.is_public,
)
try:
update_document_set(
document_set_update_request=document_set_update_request,

View File

@ -69,7 +69,7 @@ def set_user_role(
if user_role_update_request.new_role == UserRole.CURATOR:
raise HTTPException(
status_code=400,
status_code=402,
detail="Curator role must be set via the User Group Menu",
)
@ -78,7 +78,7 @@ def set_user_role(
if current_user.id == user_to_update.id:
raise HTTPException(
status_code=400,
status_code=402,
detail="An admin cannot demote themselves from admin role!",
)

View File

@ -2,6 +2,7 @@ from collections.abc import Sequence
from operator import and_
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
@ -30,6 +31,50 @@ from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()
def validate_user_creation_permissions(
    db_session: Session,
    user: User | None,
    target_group_ids: list[int] | None,
    object_is_public: bool | None,
) -> None:
    """
    All admin actions are allowed.
    Prevents non-admins from creating/editing:
    - public objects
    - objects with no groups
    - objects that belong to a group they don't curate

    Args:
        db_session: Session used to look up the groups the user curates.
        user: The acting user; ``None`` is treated like an admin (no checks).
        target_group_ids: Group ids the object would be attached to.
        object_is_public: Whether the object is being created/edited as public.

    Raises:
        HTTPException: With status 402 when any of the rules above is violated.
    """
    if not user or user.role == UserRole.ADMIN:
        return

    # NOTE(review): 402 is "Payment Required"; 403 looks like the intended
    # status for a permission failure. Left unchanged here because clients may
    # already depend on it - confirm before changing.
    if object_is_public:
        # This helper guards connectors, cc-pairs, credentials and document
        # sets alike, so the message must not single out credentials.
        detail = "User does not have permission to create public objects"
        logger.error(detail)
        raise HTTPException(
            status_code=402,
            detail=detail,
        )

    if not target_group_ids:
        detail = "Curators must specify 1+ groups"
        logger.error(detail)
        raise HTTPException(
            status_code=402,
            detail=detail,
        )

    user_curated_groups = fetch_user_groups_for_user(
        db_session=db_session, user_id=user.id, only_curator_groups=True
    )
    user_curated_group_ids = {group.id for group in user_curated_groups}

    # Every requested group must be one the user curates.
    if not set(target_group_ids).issubset(user_curated_group_ids):
        detail = "Curators cannot control groups they don't curate"
        logger.error(detail)
        raise HTTPException(
            status_code=402,
            detail=detail,
        )
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
    """Return the ``UserGroup`` whose id equals ``user_group_id``, if the query yields one."""
    return db_session.scalar(
        select(UserGroup).where(UserGroup.id == user_group_id)
    )

View File

@ -9,6 +9,7 @@ from danswer.auth.users import current_curator_or_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import fetch_user_groups
from ee.danswer.db.user_group import fetch_user_groups_for_user
from ee.danswer.db.user_group import insert_user_group
@ -20,6 +21,8 @@ from ee.danswer.server.user_group.models import UserGroup
from ee.danswer.server.user_group.models import UserGroupCreate
from ee.danswer.server.user_group.models import UserGroupUpdate
logger = setup_logger()
router = APIRouter(prefix="/manage")
@ -90,6 +93,7 @@ def set_user_curator(
set_curator_request=set_curator_request,
)
except ValueError as e:
logger.error(f"Error setting user curator: {e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@ -0,0 +1,70 @@
# Integration Tests
## General Testing Overview
The integration tests are designed with a "manager" class and a "test" class for each type of object being manipulated (e.g., user, persona, credential):
- **Manager Class**: Contains methods for each type of API call. Responsible for creating, deleting, and verifying the existence of an entity.
- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object.
The idea is that each test can use the manager class to create (.create()) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the "test_" object is in the expected state by using the manager class (.verify()) function.
## Instructions for Running Integration Tests Locally
1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080.
a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory.
2. Navigate to `danswer/backend`.
3. Run the following command in the terminal:
```sh
pytest -s tests/integration/tests/
```
or to run all tests in a file:
```sh
pytest -s tests/integration/tests/path_to/test_file.py
```
or to run a single test:
```sh
pytest -s tests/integration/tests/path_to/test_file.py::test_function_name
```
## Guidelines for Writing Integration Tests
- As authentication is currently required for all tests, each test should start by creating a user.
- Each test should ideally focus on a single API flow.
- The test writer should try to consider failure cases and edge cases for the flow and write the tests to check for these cases.
- Every step of the test should be commented describing what is being done and what the expected behavior is.
- A summary of the test should be given at the top of the test function as well!
- When writing new tests, manager classes, manager functions, and test classes, try to copy the style of the other ones that have already been written.
- Be careful of scope creep!
- No need to overcomplicate every test by verifying after every single API call so long as the case you would be verifying is covered elsewhere (ideally in a test focused on covering that case).
- An example of this is: Creating an admin user is done at the beginning of nearly every test, but we only need to verify that the user is actually an admin in the test focused on checking admin permissions. For every other test, we can just create the admin user and assume that the permissions are working as expected.
## Current Testing Limitations
### Test coverage
- All tests are probably not as high coverage as they could be.
- The "connector" tests in particular are super bare bones because we will be reworking connector/cc_pair sometime soon.
- Global Curator role is not thoroughly tested.
- The no-auth configuration is not tested at all.
### Failure checking
- While we test expected auth failures, we only check that it failed at all.
- We don't check that the return codes are what we expect.
- This means that a test could be failing for a different reason than expected.
- We should ensure that the proper codes are being returned for each failure case.
- We should also query the db after each failure to ensure that the db is in the expected state.
### Scope/focus
- The tests may be scoped sub-optimally.
- The scoping of each test may be overlapping.
## Current Testing Coverage
The current testing coverage should be checked by reading the comments at the top of each test file.
## TODO: Testing Coverage
- Persona permissions testing
- Read only (and/or basic) user permissions
- Ensuring proper permission enforcement using the chat/doc_search endpoints
- No auth
## Ideas for integration testing design
### Combine the "test" and "manager" classes
This could make test writing a bit cleaner by preventing test writers from having to pass around objects into functions that the objects have a 1:1 relationship with.
### Rework VespaClient
Right now, it's used as a fixture and has to be passed around between manager classes.
It could just be built where it's used.

View File

@ -1,114 +0,0 @@
import uuid
from typing import cast
import requests
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.db.enums import ConnectorCredentialPairStatus
from tests.integration.common_utils.constants import API_SERVER_URL
class ConnectorCreationDetails(BaseModel):
    """Ids returned by ``ConnectorClient.create_connector`` for the objects it set up."""

    # Id of the created connector.
    connector_id: int
    # Id of the credential associated with the connector.
    credential_id: int
    # Id of the connector-credential pair linking the two.
    cc_pair_id: int
class ConnectorClient:
    """Static helpers that drive the connector-related HTTP endpoints of the API server."""

    @staticmethod
    def create_connector(
        name_prefix: str = "test_connector", credential_id: int | None = None
    ) -> ConnectorCreationDetails:
        """Create a connector, attach a credential, and return the resulting ids.

        Args:
            name_prefix: Prefix for the generated, UUID-suffixed connector name.
            credential_id: Existing credential to associate. When ``None`` a
                new credential is created first.

        Returns:
            The connector id, credential id and cc-pair id.

        Raises:
            requests.HTTPError: If any API call returns an error status.
            ValueError: If the new cc-pair is missing from the indexing status
                listing.
        """
        unique_id = uuid.uuid4()

        connector_name = f"{name_prefix}_{unique_id}"
        connector_data = {
            "name": connector_name,
            "source": DocumentSource.NOT_APPLICABLE,
            "input_type": "load_state",
            "connector_specific_config": {},
            "refresh_freq": 60,
            "disabled": True,
        }
        response = requests.post(
            f"{API_SERVER_URL}/manage/admin/connector",
            json=connector_data,
        )
        response.raise_for_status()
        connector_id = response.json()["id"]

        # associate the credential with the connector
        # Compare against None explicitly: a plain truthiness test would treat
        # an explicitly supplied id of 0 as "not specified".
        if credential_id is None:
            print("ID not specified, creating new credential")
            # Create a new credential
            credential_data = {
                "credential_json": {},
                "admin_public": True,
                "source": DocumentSource.NOT_APPLICABLE,
            }
            response = requests.post(
                f"{API_SERVER_URL}/manage/credential",
                json=credential_data,
            )
            response.raise_for_status()
            credential_id = cast(int, response.json()["id"])

        cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True}
        response = requests.put(
            f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}",
            json=cc_pair_metadata,
        )
        response.raise_for_status()

        # fetch the connector credential pair id using the indexing status API
        response = requests.get(
            f"{API_SERVER_URL}/manage/admin/connector/indexing-status"
        )
        response.raise_for_status()
        indexing_statuses = response.json()

        cc_pair_id = None
        for status in indexing_statuses:
            if (
                status["connector"]["id"] == connector_id
                and status["credential"]["id"] == credential_id
            ):
                cc_pair_id = status["cc_pair_id"]
                break

        if cc_pair_id is None:
            raise ValueError("Could not find the connector credential pair id")

        print(
            f"Created connector with connector_id: {connector_id}, credential_id: {credential_id}, cc_pair_id: {cc_pair_id}"
        )
        return ConnectorCreationDetails(
            connector_id=int(connector_id),
            credential_id=int(credential_id),
            cc_pair_id=int(cc_pair_id),
        )

    @staticmethod
    def update_connector_status(
        cc_pair_id: int, status: ConnectorCredentialPairStatus
    ) -> None:
        """PUT ``status`` to the cc-pair status endpoint; raises on an error response."""
        response = requests.put(
            f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/status",
            json={"status": status},
        )
        response.raise_for_status()

    @staticmethod
    def delete_connector(connector_id: int, credential_id: int) -> None:
        """POST a deletion attempt for the given connector/credential pair; raises on an error response."""
        response = requests.post(
            f"{API_SERVER_URL}/manage/admin/deletion-attempt",
            json={"connector_id": connector_id, "credential_id": credential_id},
        )
        response.raise_for_status()

    @staticmethod
    def get_connectors() -> list[dict]:
        """Return the decoded JSON body of the connector listing endpoint."""
        response = requests.get(f"{API_SERVER_URL}/manage/connector")
        response.raise_for_status()
        return response.json()

View File

@ -5,3 +5,7 @@ API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 30
GENERAL_HEADERS = {"Content-Type": "application/json"}
NUM_DOCS = 5

View File

@ -1,30 +0,0 @@
from typing import cast
import requests
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common_utils.constants import API_SERVER_URL
class DocumentSetClient:
    """Static helpers for the document-set HTTP endpoints of the API server."""

    @staticmethod
    def create_document_set(
        doc_set_creation_request: DocumentSetCreationRequest,
    ) -> int:
        """POST a document-set creation request and return the response body.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        response = requests.post(
            f"{API_SERVER_URL}/manage/admin/document-set",
            json=doc_set_creation_request.model_dump(),
        )
        response.raise_for_status()
        return cast(int, response.json())

    @staticmethod
    def fetch_document_sets() -> list[DocumentSet]:
        """Fetch the document sets and parse each entry into a ``DocumentSet``.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        response = requests.get(f"{API_SERVER_URL}/manage/document-set")
        response.raise_for_status()

        # model_validate is the pydantic v2 replacement for the deprecated
        # parse_obj, matching the model_dump() call used above.
        return [
            DocumentSet.model_validate(doc_set_data)
            for doc_set_data in response.json()
        ]

View File

@ -1,62 +1,88 @@
import os
from typing import cast
from uuid import uuid4
import requests
from pydantic import BaseModel
from pydantic import PrivateAttr
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestLLMProvider
from tests.integration.common_utils.test_models import TestUser
class LLMProvider(BaseModel):
provider: str
api_key: str
default_model_name: str
api_base: str | None = None
api_version: str | None = None
is_default: bool = True
class LLMProviderManager:
@staticmethod
def create(
name: str | None = None,
provider: str | None = None,
api_key: str | None = None,
default_model_name: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
groups: list[int] | None = None,
is_public: bool | None = None,
user_performing_action: TestUser | None = None,
) -> TestLLMProvider:
print("Seeding LLM Providers...")
# only populated after creation
_provider_id: int | None = PrivateAttr()
def create(self) -> int:
llm_provider = LLMProviderUpsertRequest(
name=self.provider,
provider=self.provider,
default_model_name=self.default_model_name,
api_key=self.api_key,
api_base=self.api_base,
api_version=self.api_version,
name=name or f"test-provider-{uuid4()}",
provider=provider or "openai",
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
custom_config=None,
fast_default_model_name=None,
is_public=True,
groups=[],
fast_default_model_name=default_model_name or "gpt-4o-mini",
is_public=is_public or True,
groups=groups or [],
display_model_names=None,
model_names=None,
)
response = requests.put(
llm_response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
json=llm_provider.dict(),
json=llm_provider.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
llm_response.raise_for_status()
response_data = llm_response.json()
result_llm = TestLLMProvider(
id=response_data["id"],
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=response_data["default_model_name"],
is_public=response_data["is_public"],
groups=response_data["groups"],
api_base=response_data["api_base"],
api_version=response_data["api_version"],
)
response.raise_for_status()
self._provider_id = cast(int, response.json()["id"])
return self._provider_id
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
set_default_response.raise_for_status()
def delete(self) -> None:
return result_llm
@staticmethod
def delete(
llm_provider: TestLLMProvider,
user_performing_action: TestUser | None = None,
) -> bool:
if not llm_provider.id:
raise ValueError("LLM Provider ID is required to delete a provider")
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{self._provider_id}"
f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
def seed_default_openai_provider() -> LLMProvider:
llm = LLMProvider(
provider="openai",
default_model_name="gpt-4o-mini",
api_key=os.environ["OPENAI_API_KEY"],
)
llm.create()
return llm
return True

View File

@ -0,0 +1,92 @@
from uuid import uuid4
import requests
from danswer.db.models import UserRole
from ee.danswer.server.api_key.models import APIKeyArgs
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
class APIKeyManager:
    """Static helpers that create, delete, list and verify API keys through the admin API."""

    @staticmethod
    def create(
        name: str | None = None,
        api_key_role: UserRole = UserRole.ADMIN,
        user_performing_action: TestUser | None = None,
    ) -> TestAPIKey:
        """Create an API key and return it with ready-to-use request headers.

        Args:
            name: Optional base name; ``-api-key`` is appended. When omitted a
                UUID-based name is generated.
            api_key_role: Role requested for the key.
            user_performing_action: User whose headers authenticate the call;
                ``GENERAL_HEADERS`` is used when omitted.

        Raises:
            requests.HTTPError: If the API returns an error status.
        """
        name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}"

        api_key_request = APIKeyArgs(
            name=name,
            role=api_key_role,
        )

        api_key_response = requests.post(
            f"{API_SERVER_URL}/admin/api-key",
            json=api_key_request.model_dump(),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        api_key_response.raise_for_status()
        api_key = api_key_response.json()
        result_api_key = TestAPIKey(
            api_key_id=api_key["api_key_id"],
            api_key_display=api_key["api_key_display"],
            api_key=api_key["api_key"],
            api_key_name=name,
            api_key_role=api_key_role,
            user_id=api_key["user_id"],
            # Hand over a private copy: the Authorization header is written
            # into this dict below and must never end up in the shared
            # module-level GENERAL_HEADERS constant.
            headers=dict(GENERAL_HEADERS),
        )
        result_api_key.headers["Authorization"] = f"Bearer {result_api_key.api_key}"
        return result_api_key

    @staticmethod
    def delete(
        api_key: TestAPIKey,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Delete ``api_key`` by id; raises ``requests.HTTPError`` on an error response."""
        api_key_response = requests.delete(
            f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        api_key_response.raise_for_status()

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[TestAPIKey]:
        """Return every API key reported by the listing endpoint as ``TestAPIKey`` objects."""
        api_key_response = requests.get(
            f"{API_SERVER_URL}/admin/api-key",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        api_key_response.raise_for_status()
        return [TestAPIKey(**api_key) for api_key in api_key_response.json()]

    @staticmethod
    def verify(
        api_key: TestAPIKey,
        verify_deleted: bool = False,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check that ``api_key`` is present (or absent) on the server.

        With ``verify_deleted`` False, succeeds only when a key with the same
        id, name and role is listed. With ``verify_deleted`` True, succeeds
        only when no key with that id is listed.

        Raises:
            ValueError: A key with this id exists although deletion was expected.
            Exception: No listed key matched on id, name and role. This also
                covers a key whose id matches but whose name or role differs.
        """
        retrieved_keys = APIKeyManager.get_all(
            user_performing_action=user_performing_action
        )
        for key in retrieved_keys:
            if key.api_key_id == api_key.api_key_id:
                if verify_deleted:
                    raise ValueError("API Key found when it should have been deleted")
                if (
                    key.api_key_name == api_key.api_key_name
                    and key.api_key_role == api_key.api_key_role
                ):
                    return

        if not verify_deleted:
            raise Exception("API Key not found")

View File

@ -0,0 +1,202 @@
import time
from typing import Any
from uuid import uuid4
import requests
from danswer.connectors.models import InputType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.test_models import TestCCPair
from tests.integration.common_utils.test_models import TestUser
def _cc_pair_creator(
    connector_id: int,
    credential_id: int,
    name: str | None = None,
    is_public: bool = True,
    groups: list[int] | None = None,
    user_performing_action: TestUser | None = None,
) -> TestCCPair:
    """Associate an existing connector with an existing credential.

    Sends ``PUT /manage/connector/{connector_id}/credential/{credential_id}``
    and returns a ``TestCCPair`` whose id is the ``"data"`` field of the
    response. A given ``name`` gets the suffix ``-cc-pair``; without one a
    random ``test-cc-pair-<uuid>`` name is generated.
    """
    if name:
        name = f"{name}-cc-pair"
    else:
        name = f"test-cc-pair-{uuid4()}"

    payload = {
        "name": name,
        "is_public": is_public,
        "groups": groups or [],
    }
    endpoint = (
        f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}"
    )
    if user_performing_action:
        request_headers = user_performing_action.headers
    else:
        request_headers = GENERAL_HEADERS

    association = requests.put(url=endpoint, json=payload, headers=request_headers)
    association.raise_for_status()

    return TestCCPair(
        id=association.json()["data"],
        name=name,
        connector_id=connector_id,
        credential_id=credential_id,
        is_public=is_public,
        groups=groups or [],
    )
class CCPairManager:
    """Test helpers for connector-credential pair (CC pair) endpoints.

    Every request is sent with ``user_performing_action.headers`` when a user
    is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create_from_scratch(
        name: str | None = None,
        is_public: bool = True,
        groups: list[int] | None = None,
        source: DocumentSource = DocumentSource.FILE,
        input_type: InputType = InputType.LOAD_STATE,
        connector_specific_config: dict[str, Any] | None = None,
        credential_json: dict[str, Any] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestCCPair:
        """Create a connector and a credential, then pair them.

        ``name``, ``groups`` and ``user_performing_action`` are passed to all
        three creation calls. ``is_public`` is used as the connector's and the
        pair's ``is_public`` and as the credential's ``curator_public``.
        """
        shared_kwargs: dict[str, Any] = {
            "name": name,
            "groups": groups,
            "user_performing_action": user_performing_action,
        }
        new_connector = ConnectorManager.create(
            source=source,
            input_type=input_type,
            connector_specific_config=connector_specific_config,
            is_public=is_public,
            **shared_kwargs,
        )
        new_credential = CredentialManager.create(
            credential_json=credential_json,
            source=source,
            curator_public=is_public,
            **shared_kwargs,
        )
        return _cc_pair_creator(
            connector_id=new_connector.id,
            credential_id=new_credential.id,
            is_public=is_public,
            **shared_kwargs,
        )

    @staticmethod
    def create(
        connector_id: int,
        credential_id: int,
        name: str | None = None,
        is_public: bool = True,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestCCPair:
        """Pair an already existing connector and credential."""
        return _cc_pair_creator(
            connector_id=connector_id,
            credential_id=credential_id,
            name=name,
            is_public=is_public,
            groups=groups,
            user_performing_action=user_performing_action,
        )

    @staticmethod
    def pause_cc_pair(
        cc_pair: TestCCPair,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Set the status of ``cc_pair`` to ``"PAUSED"``."""
        status_url = f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        outcome = requests.put(
            url=status_url,
            json={"status": "PAUSED"},
            headers=request_headers,
        )
        outcome.raise_for_status()

    @staticmethod
    def delete(
        cc_pair: TestCCPair,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Submit a deletion attempt for ``cc_pair``.

        This only posts the request; use ``wait_for_deletion_completion`` to
        wait until no pair is reported as deleting any more.
        """
        identifier = ConnectorCredentialPairIdentifier(
            connector_id=cc_pair.connector_id,
            credential_id=cc_pair.credential_id,
        )
        deletion_url = f"{API_SERVER_URL}/manage/admin/deletion-attempt"
        payload = identifier.model_dump()
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        outcome = requests.post(
            url=deletion_url,
            json=payload,
            headers=request_headers,
        )
        outcome.raise_for_status()

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[ConnectorIndexingStatus]:
        """Return the indexing status of every CC pair the server reports."""
        status_url = f"{API_SERVER_URL}/manage/admin/connector/indexing-status"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        listing = requests.get(status_url, headers=request_headers)
        listing.raise_for_status()

        statuses: list[ConnectorIndexingStatus] = []
        for raw_status in listing.json():
            statuses.append(ConnectorIndexingStatus(**raw_status))
        return statuses

    @staticmethod
    def verify(
        cc_pair: TestCCPair,
        verify_deleted: bool = False,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check that ``cc_pair`` exists on the server (or that it is gone).

        With ``verify_deleted=False`` the call returns once a pair with the
        same id, name, connector id, credential id, public flag and group set
        is found; otherwise ``ValueError`` is raised (also when the id exists
        but one of those fields differs). With ``verify_deleted=True`` it
        raises ``ValueError`` if the id is still listed.
        """
        for fetched in CCPairManager.get_all(user_performing_action):
            if fetched.cc_pair_id != cc_pair.id:
                continue
            if verify_deleted:
                # The caller is expected to have waited for the deletion to
                # finish before running this check.
                raise ValueError(f"CC pair {cc_pair.id} found but should be deleted")
            fields_match = (
                fetched.name == cc_pair.name
                and fetched.connector.id == cc_pair.connector_id
                and fetched.credential.id == cc_pair.credential_id
                and fetched.public_doc == cc_pair.is_public
                and set(fetched.groups) == set(cc_pair.groups)
            )
            if fields_match:
                return

        if verify_deleted:
            return
        raise ValueError(f"CC pair {cc_pair.id} not found")

    @staticmethod
    def wait_for_deletion_completion(
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Poll until no CC pair has the ``DELETING`` status.

        The status list is re-fetched every 2 seconds. ``TimeoutError`` is
        raised once more than ``MAX_DELAY`` seconds have passed since the
        call started.
        """
        started_at = time.time()
        while True:
            current_pairs = CCPairManager.get_all(user_performing_action)
            deletion_pending = any(
                pair.cc_pair_status == ConnectorCredentialPairStatus.DELETING
                for pair in current_pairs
            )
            if not deletion_pending:
                return

            if time.time() - started_at > MAX_DELAY:
                raise TimeoutError(
                    f"CC pairs deletion was not completed within the {MAX_DELAY} seconds"
                )

            print("Some CC pairs are still being deleted, waiting...")
            time.sleep(2)

View File

@ -0,0 +1,124 @@
from typing import Any
from uuid import uuid4
import requests
from danswer.connectors.models import InputType
from danswer.server.documents.models import ConnectorUpdateRequest
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestConnector
from tests.integration.common_utils.test_models import TestUser
class ConnectorManager:
    """Test helpers for the connector endpoints.

    Every request is sent with ``user_performing_action.headers`` when a user
    is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def _from_json(raw: dict[str, Any]) -> TestConnector:
        """Build a ``TestConnector`` from one connector JSON object.

        Missing keys fall back to an empty name, ``DocumentSource.FILE``,
        ``InputType.LOAD_STATE`` and an empty config; a missing ``"id"``
        becomes ``None``.
        """
        return TestConnector(
            id=raw.get("id"),
            name=raw.get("name", ""),
            source=raw.get("source", DocumentSource.FILE),
            input_type=raw.get("input_type", InputType.LOAD_STATE),
            connector_specific_config=raw.get("connector_specific_config", {}),
        )

    @staticmethod
    def create(
        name: str | None = None,
        source: DocumentSource = DocumentSource.FILE,
        input_type: InputType = InputType.LOAD_STATE,
        connector_specific_config: dict[str, Any] | None = None,
        is_public: bool = True,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestConnector:
        """Create a connector via ``POST /manage/admin/connector``.

        A given ``name`` gets the suffix ``-connector``; without one a random
        ``test-connector-<uuid>`` name is generated. The returned model takes
        its id from the response and every other field from the arguments.
        """
        if name:
            name = f"{name}-connector"
        else:
            name = f"test-connector-{uuid4()}"

        creation_request = ConnectorUpdateRequest(
            name=name,
            source=source,
            input_type=input_type,
            connector_specific_config=connector_specific_config or {},
            is_public=is_public,
            groups=groups or [],
        )
        endpoint = f"{API_SERVER_URL}/manage/admin/connector"
        payload = creation_request.model_dump()
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS

        creation = requests.post(url=endpoint, json=payload, headers=request_headers)
        creation.raise_for_status()
        created = creation.json()

        return TestConnector(
            id=created.get("id"),
            name=name,
            source=source,
            input_type=input_type,
            connector_specific_config=connector_specific_config or {},
            groups=groups,
            is_public=is_public,
        )

    @staticmethod
    def edit(
        connector: TestConnector,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Patch the connector with the fields of ``connector`` (minus ``id``)."""
        endpoint = f"{API_SERVER_URL}/manage/admin/connector/{connector.id}"
        payload = connector.model_dump(exclude={"id"})
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        update = requests.patch(url=endpoint, json=payload, headers=request_headers)
        update.raise_for_status()

    @staticmethod
    def delete(
        connector: TestConnector,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Delete ``connector`` via ``DELETE /manage/admin/connector/{id}``."""
        endpoint = f"{API_SERVER_URL}/manage/admin/connector/{connector.id}"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        removal = requests.delete(url=endpoint, headers=request_headers)
        removal.raise_for_status()

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[TestConnector]:
        """Return every connector listed by ``GET /manage/connector``."""
        endpoint = f"{API_SERVER_URL}/manage/connector"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        listing = requests.get(url=endpoint, headers=request_headers)
        listing.raise_for_status()
        return [ConnectorManager._from_json(raw) for raw in listing.json()]

    @staticmethod
    def get(
        connector_id: int, user_performing_action: TestUser | None = None
    ) -> TestConnector:
        """Return the connector served at ``GET /manage/connector/{connector_id}``."""
        endpoint = f"{API_SERVER_URL}/manage/connector/{connector_id}"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        lookup = requests.get(url=endpoint, headers=request_headers)
        lookup.raise_for_status()
        return ConnectorManager._from_json(lookup.json())

View File

@ -0,0 +1,129 @@
from typing import Any
from uuid import uuid4
import requests
from danswer.server.documents.models import CredentialSnapshot
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestCredential
from tests.integration.common_utils.test_models import TestUser
class CredentialManager:
    """Test helpers for the credential endpoints.

    Every request is sent with ``user_performing_action.headers`` when a user
    is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create(
        credential_json: dict[str, Any] | None = None,
        admin_public: bool = True,
        name: str | None = None,
        source: DocumentSource = DocumentSource.FILE,
        curator_public: bool = True,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestCredential:
        """Create a credential via ``POST /manage/credential``.

        A given ``name`` gets the suffix ``-credential``; without one a random
        ``test-credential-<uuid>`` name is generated. The returned model takes
        its id from the response and every other field from the arguments.
        """
        if name:
            name = f"{name}-credential"
        else:
            name = f"test-credential-{uuid4()}"

        payload = {
            "name": name,
            "credential_json": credential_json or {},
            "admin_public": admin_public,
            "source": source,
            "curator_public": curator_public,
            "groups": groups or [],
        }
        endpoint = f"{API_SERVER_URL}/manage/credential"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS

        creation = requests.post(url=endpoint, json=payload, headers=request_headers)
        creation.raise_for_status()

        return TestCredential(
            id=creation.json()["id"],
            name=name,
            credential_json=credential_json or {},
            admin_public=admin_public,
            source=source,
            curator_public=curator_public,
            groups=groups or [],
        )

    @staticmethod
    def edit(
        credential: TestCredential,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Update the credential; only ``name`` and ``credential_json`` are sent."""
        payload = credential.model_dump(include={"name", "credential_json"})
        endpoint = f"{API_SERVER_URL}/manage/admin/credential/{credential.id}"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        update = requests.put(url=endpoint, json=payload, headers=request_headers)
        update.raise_for_status()

    @staticmethod
    def delete(
        credential: TestCredential,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Delete ``credential`` via ``DELETE /manage/credential/{id}``."""
        endpoint = f"{API_SERVER_URL}/manage/credential/{credential.id}"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        removal = requests.delete(url=endpoint, headers=request_headers)
        removal.raise_for_status()

    @staticmethod
    def get(
        credential_id: int, user_performing_action: TestUser | None = None
    ) -> CredentialSnapshot:
        """Return the snapshot served at ``GET /manage/credential/{credential_id}``."""
        endpoint = f"{API_SERVER_URL}/manage/credential/{credential_id}"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        lookup = requests.get(url=endpoint, headers=request_headers)
        lookup.raise_for_status()
        return CredentialSnapshot(**lookup.json())

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[CredentialSnapshot]:
        """Return a snapshot for every credential listed by the server."""
        endpoint = f"{API_SERVER_URL}/manage/credential"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        listing = requests.get(endpoint, headers=request_headers)
        listing.raise_for_status()

        snapshots: list[CredentialSnapshot] = []
        for raw_credential in listing.json():
            snapshots.append(CredentialSnapshot(**raw_credential))
        return snapshots

    @staticmethod
    def verify(
        credential: TestCredential,
        verify_deleted: bool = False,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check that ``credential`` exists on the server (or that it is gone).

        With ``verify_deleted=False`` the call returns once a credential with
        the same id, name, ``admin_public``, source and ``curator_public`` is
        listed; otherwise ``ValueError`` is raised (also when the id exists
        but one of those fields differs). With ``verify_deleted=True`` it
        raises ``ValueError`` if the id is still listed.
        """
        for fetched in CredentialManager.get_all(user_performing_action):
            if credential.id != fetched.id:
                continue
            if verify_deleted:
                raise ValueError(
                    f"Credential {credential.id} found but should be deleted"
                )
            fields_match = (
                credential.name == fetched.name
                and credential.admin_public == fetched.admin_public
                and credential.source == fetched.source
                and credential.curator_public == fetched.curator_public
            )
            if fields_match:
                return

        if verify_deleted:
            return
        raise ValueError(f"Credential {credential.id} not found")

View File

@ -0,0 +1,153 @@
from uuid import uuid4
import requests
from danswer.configs.constants import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import TestAPIKey
from tests.integration.common_utils.managers.cc_pair import TestCCPair
from tests.integration.common_utils.test_models import SimpleTestDocument
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.vespa import TestVespaClient
def _verify_document_permissions(
    retrieved_doc: dict,
    cc_pair: TestCCPair,
    doc_set_names: list[str] | None = None,
    group_names: list[str] | None = None,
    doc_creating_user: TestUser | None = None,
) -> None:
    """Validate the ACL keys and document sets of one retrieved document.

    Checks, in order:
    * a public ``cc_pair`` requires the ``PUBLIC`` ACL key;
    * a given ``doc_creating_user`` requires the ``user_id:<id>`` ACL key;
    * a given ``group_names`` must equal the set of ``group:`` ACL keys;
    * a given ``doc_set_names`` must equal the keys of ``document_sets``.

    ``None`` for ``group_names`` / ``doc_set_names`` skips that check, while
    an empty list demands that no groups / document sets are present.
    Raises ``ValueError`` on the first failed check.
    """
    doc_id = retrieved_doc["document_id"]
    acl_keys = set(retrieved_doc["access_control_list"].keys())
    print(f"ACL keys: {acl_keys}")

    if cc_pair.is_public and "PUBLIC" not in acl_keys:
        raise ValueError(
            f"Document {doc_id} is public but" " does not have the PUBLIC ACL key"
        )

    if doc_creating_user is not None:
        creator_key = f"user_id:{doc_creating_user.id}"
        if creator_key not in acl_keys:
            raise ValueError(
                f"Document {doc_id} was created by user"
                f" {doc_creating_user.id} but does not have the {creator_key} ACL key"
            )

    if group_names is not None:
        wanted_group_keys = {f"group:{group_name}" for group_name in group_names}
        present_group_keys = {key for key in acl_keys if key.startswith("group:")}
        if present_group_keys != wanted_group_keys:
            raise ValueError(
                f"Document {doc_id} has incorrect group ACL keys. Found: {present_group_keys}, \n"
                f"Expected: {wanted_group_keys}"
            )

    if doc_set_names is not None:
        present_doc_sets = set(retrieved_doc.get("document_sets", {}).keys())
        wanted_doc_sets = set(doc_set_names)
        if present_doc_sets != wanted_doc_sets:
            raise ValueError(
                f"Document set names mismatch. \nFound: {present_doc_sets}, \n"
                f"Expected: {wanted_doc_sets}"
            )
def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict:
    """Build the ingestion payload for one synthetic test document.

    The payload has a ``"document"`` entry with a single text section derived
    from ``document_id`` and a ``"cc_pair_id"`` entry naming the pair the
    document is attached to.
    """
    only_section = {
        "text": f"This is test document {document_id}",
        "link": f"{document_id}",
    }
    document_body = {
        "id": document_id,
        "sections": [only_section],
        "source": DocumentSource.NOT_APPLICABLE,
        # metadata is only here so tests have some to look at
        "metadata": {"document_id": document_id},
        "semantic_identifier": f"Test Document {document_id}",
        "from_ingestion_api": True,
    }
    return {"document": document_body, "cc_pair_id": cc_pair_id}
class DocumentManager:
    """Test helpers for seeding documents and checking them in Vespa."""

    @staticmethod
    def seed_and_attach_docs(
        cc_pair: TestCCPair,
        num_docs: int = NUM_DOCS,
        document_ids: list[str] | None = None,
        api_key: TestAPIKey | None = None,
    ) -> TestCCPair:
        """Ingest dummy documents for ``cc_pair`` and record them on it.

        When ``document_ids`` is given those ids are used and ``num_docs`` is
        ignored; otherwise ``num_docs`` random ``test-doc-<uuid>`` ids are
        generated. Each document is posted individually to the ingestion
        endpoint, using ``api_key.headers`` when an API key is supplied and
        ``GENERAL_HEADERS`` otherwise.

        ``cc_pair.documents`` is overwritten in place and the same object is
        returned.
        """
        if document_ids is None:
            document_ids = [f"test-doc-{uuid4()}" for _ in range(num_docs)]

        ingestion_url = f"{API_SERVER_URL}/danswer-api/ingestion"
        payloads: list[dict] = [
            _generate_dummy_document(document_id, cc_pair.id)
            for document_id in document_ids
        ]
        for payload in payloads:
            ingestion = requests.post(
                ingestion_url,
                json=payload,
                headers=api_key.headers if api_key else GENERAL_HEADERS,
            )
            ingestion.raise_for_status()

        print("Seeding completed successfully.")

        seeded_docs = []
        for payload in payloads:
            body = payload["document"]
            seeded_docs.append(
                SimpleTestDocument(
                    id=body["id"],
                    content=body["sections"][0]["text"],
                )
            )
        cc_pair.documents = seeded_docs
        return cc_pair

    @staticmethod
    def verify(
        vespa_client: TestVespaClient,
        cc_pair: TestCCPair,
        doc_set_names: list[str] | None = None,
        group_names: list[str] | None = None,
        doc_creating_user: TestUser | None = None,
        verify_deleted: bool = False,
    ) -> None:
        """Check the documents of ``cc_pair`` against what Vespa returns.

        ``doc_set_names`` / ``group_names``: ``None`` skips that check, an
        empty list checks that the document has no doc sets / groups.

        With ``verify_deleted=False`` every document must be present (else
        ``ValueError``) and is passed to ``_verify_document_permissions``.
        With ``verify_deleted=True`` any document that is still present
        raises ``ValueError``.
        """
        wanted_ids = [document.id for document in cc_pair.documents]
        vespa_hits = vespa_client.get_documents_by_id(wanted_ids)["documents"]
        fields_by_id = {}
        for hit in vespa_hits:
            fields_by_id[hit["fields"]["document_id"]] = hit["fields"]

        # Debugging tip: the entries of fields_by_id can be dumped as JSON once
        # the "embeddings" and "title_embedding" keys are popped.

        for document in cc_pair.documents:
            doc_fields = fields_by_id.get(document.id)

            if doc_fields:
                if verify_deleted:
                    raise ValueError(
                        f"Document found when it should be deleted: {document.id}"
                    )
                _verify_document_permissions(
                    doc_fields,
                    cc_pair,
                    doc_set_names,
                    group_names,
                    doc_creating_user,
                )
            elif not verify_deleted:
                raise ValueError(f"Document not found: {document.id}")

View File

@ -0,0 +1,171 @@
import time
from uuid import uuid4
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import TestDocumentSet
from tests.integration.common_utils.test_models import TestUser
class DocumentSetManager:
    """Test helpers for the document-set endpoints.

    Every request is sent with ``user_performing_action.headers`` when a user
    is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create(
        name: str | None = None,
        description: str | None = None,
        cc_pair_ids: list[int] | None = None,
        is_public: bool = True,
        users: list[str] | None = None,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestDocumentSet:
        """Create a document set via ``POST /manage/admin/document-set``.

        A missing ``name`` is replaced by a random ``test_doc_set_<uuid>``
        name and a missing/empty ``description`` falls back to the name. The
        returned model uses the response body as id and is created with
        ``is_up_to_date=True``.
        """
        if name is None:
            name = f"test_doc_set_{uuid4()}"

        payload = {
            "name": name,
            "description": description or name,
            "cc_pair_ids": cc_pair_ids or [],
            "is_public": is_public,
            "users": users or [],
            "groups": groups or [],
        }
        endpoint = f"{API_SERVER_URL}/manage/admin/document-set"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS

        creation = requests.post(endpoint, json=payload, headers=request_headers)
        creation.raise_for_status()

        return TestDocumentSet(
            id=int(creation.json()),
            name=name,
            description=description or name,
            cc_pair_ids=cc_pair_ids or [],
            is_public=is_public,
            is_up_to_date=True,
            users=users or [],
            groups=groups or [],
        )

    @staticmethod
    def edit(
        document_set: TestDocumentSet,
        user_performing_action: TestUser | None = None,
    ) -> bool:
        """Patch the document set with the current fields of ``document_set``.

        The name is not part of the request. Returns ``True`` when the
        request succeeded; an error status raises instead.
        """
        payload = {
            "id": document_set.id,
            "description": document_set.description,
            "cc_pair_ids": document_set.cc_pair_ids,
            "is_public": document_set.is_public,
            "users": document_set.users,
            "groups": document_set.groups,
        }
        endpoint = f"{API_SERVER_URL}/manage/admin/document-set"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        update = requests.patch(endpoint, json=payload, headers=request_headers)
        update.raise_for_status()
        return True

    @staticmethod
    def delete(
        document_set: TestDocumentSet,
        user_performing_action: TestUser | None = None,
    ) -> bool:
        """Delete ``document_set``; returns ``True`` unless the request raises."""
        endpoint = f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        removal = requests.delete(endpoint, headers=request_headers)
        removal.raise_for_status()
        return True

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[TestDocumentSet]:
        """Return every document set listed by ``GET /manage/document-set``.

        ``cc_pair_ids`` is filled from the ``id`` of each entry in the
        response's ``cc_pair_descriptors``.
        """
        endpoint = f"{API_SERVER_URL}/manage/document-set"
        if user_performing_action:
            request_headers = user_performing_action.headers
        else:
            request_headers = GENERAL_HEADERS
        listing = requests.get(endpoint, headers=request_headers)
        listing.raise_for_status()

        document_sets: list[TestDocumentSet] = []
        for raw_set in listing.json():
            paired_ids = [
                descriptor["id"] for descriptor in raw_set["cc_pair_descriptors"]
            ]
            document_sets.append(
                TestDocumentSet(
                    id=raw_set["id"],
                    name=raw_set["name"],
                    description=raw_set["description"],
                    cc_pair_ids=paired_ids,
                    is_public=raw_set["is_public"],
                    is_up_to_date=raw_set["is_up_to_date"],
                    users=raw_set["users"],
                    groups=raw_set["groups"],
                )
            )
        return document_sets

    @staticmethod
    def wait_for_sync(
        document_sets_to_check: list[TestDocumentSet] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Poll until the document sets report ``is_up_to_date``.

        With ``document_sets_to_check`` only those sets are considered, and
        ``RuntimeError`` is raised if any of them is missing from the server's
        list. Without it, all listed sets must be up to date. The list is
        re-fetched every 2 seconds; ``TimeoutError`` is raised once more than
        ``MAX_DELAY`` seconds have passed.
        """
        started_at = time.time()
        while True:
            pending_sets = DocumentSetManager.get_all(user_performing_action)

            if document_sets_to_check:
                wanted_ids = {wanted.id for wanted in document_sets_to_check}
                listed_ids = {listed.id for listed in pending_sets}
                if not wanted_ids <= listed_ids:
                    raise RuntimeError("Document set not found")
                pending_sets = [
                    listed for listed in pending_sets if listed.id in wanted_ids
                ]

            if all(listed.is_up_to_date for listed in pending_sets):
                return

            if time.time() - started_at > MAX_DELAY:
                raise TimeoutError(
                    f"Document sets were not synced within the {MAX_DELAY} seconds"
                )

            print("Document sets were not synced yet, waiting...")
            time.sleep(2)

    @staticmethod
    def verify(
        document_set: TestDocumentSet,
        verify_deleted: bool = False,
        user_performing_action: TestUser | None = None,
    ) -> None:
        """Check that ``document_set`` exists on the server (or that it is gone).

        With ``verify_deleted=False`` the call returns once a set with the
        same id, name, CC pair ids, public flag, users and groups is listed;
        otherwise ``ValueError`` is raised (also when the id exists but one of
        those fields differs). With ``verify_deleted=True`` it raises
        ``ValueError`` if the id is still listed.
        """
        for fetched in DocumentSetManager.get_all(user_performing_action):
            if fetched.id != document_set.id:
                continue
            if verify_deleted:
                raise ValueError(
                    f"Document set {document_set.id} found but should have been deleted"
                )
            fields_match = (
                fetched.name == document_set.name
                and set(fetched.cc_pair_ids) == set(document_set.cc_pair_ids)
                and fetched.is_public == document_set.is_public
                and set(fetched.users) == set(document_set.users)
                and set(fetched.groups) == set(document_set.groups)
            )
            if fields_match:
                return

        if verify_deleted:
            return
        raise ValueError(f"Document set {document_set.id} not found")

View File

@ -0,0 +1,206 @@
from uuid import uuid4
import requests
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import PersonaSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestPersona
from tests.integration.common_utils.test_models import TestUser
class PersonaManager:
    """Test helpers for the persona endpoints.

    Every request is sent with ``user_performing_action.headers`` when a user
    is supplied and with ``GENERAL_HEADERS`` otherwise.
    """

    @staticmethod
    def create(
        name: str | None = None,
        description: str | None = None,
        num_chunks: float = 5,
        llm_relevance_filter: bool = True,
        is_public: bool = True,
        llm_filter_extraction: bool = True,
        recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
        prompt_ids: list[int] | None = None,
        document_set_ids: list[int] | None = None,
        tool_ids: list[int] | None = None,
        llm_model_provider_override: str | None = None,
        llm_model_version_override: str | None = None,
        users: list[str] | None = None,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestPersona:
        """Create a persona via ``POST /persona``.

        A missing/empty ``name`` is replaced by a random
        ``test-persona-<uuid>`` name, a missing/empty ``description`` by
        ``"Description for <name>"``, and missing id/user/group lists by empty
        lists. The returned model takes its id from the response and every
        other field from the (defaulted) arguments.
        """
        name = name or f"test-persona-{uuid4()}"
        description = description or f"Description for {name}"

        persona_creation_request = {
            "name": name,
            "description": description,
            "num_chunks": num_chunks,
            "llm_relevance_filter": llm_relevance_filter,
            "is_public": is_public,
            "llm_filter_extraction": llm_filter_extraction,
            "recency_bias": recency_bias,
            "prompt_ids": prompt_ids or [],
            "document_set_ids": document_set_ids or [],
            "tool_ids": tool_ids or [],
            "llm_model_provider_override": llm_model_provider_override,
            "llm_model_version_override": llm_model_version_override,
            "users": users or [],
            "groups": groups or [],
        }

        response = requests.post(
            f"{API_SERVER_URL}/persona",
            json=persona_creation_request,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        persona_data = response.json()

        return TestPersona(
            id=persona_data["id"],
            name=name,
            description=description,
            num_chunks=num_chunks,
            llm_relevance_filter=llm_relevance_filter,
            is_public=is_public,
            llm_filter_extraction=llm_filter_extraction,
            recency_bias=recency_bias,
            prompt_ids=prompt_ids or [],
            document_set_ids=document_set_ids or [],
            tool_ids=tool_ids or [],
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            users=users or [],
            groups=groups or [],
        )

    @staticmethod
    def edit(
        persona: TestPersona,
        name: str | None = None,
        description: str | None = None,
        num_chunks: float | None = None,
        llm_relevance_filter: bool | None = None,
        is_public: bool | None = None,
        llm_filter_extraction: bool | None = None,
        recency_bias: RecencyBiasSetting | None = None,
        prompt_ids: list[int] | None = None,
        document_set_ids: list[int] | None = None,
        tool_ids: list[int] | None = None,
        llm_model_provider_override: str | None = None,
        llm_model_version_override: str | None = None,
        users: list[str] | None = None,
        groups: list[int] | None = None,
        user_performing_action: TestUser | None = None,
    ) -> TestPersona:
        """Update ``persona`` via ``PATCH /persona/{id}`` and return the result.

        Every keyword left at ``None`` keeps the value currently stored on
        ``persona``; any other value, including ``False``, ``0`` and an empty
        list, overrides it. The returned ``TestPersona`` is built from the
        server's response, not from the arguments.
        """

        def _pick(override, current):  # type: ignore[no-untyped-def]
            # ``override or current`` would silently drop explicit falsy
            # overrides (e.g. is_public=False or an empty id list), so only
            # ``None`` means "keep the current value".
            return current if override is None else override

        persona_update_request = {
            "name": _pick(name, persona.name),
            "description": _pick(description, persona.description),
            "num_chunks": _pick(num_chunks, persona.num_chunks),
            "llm_relevance_filter": _pick(
                llm_relevance_filter, persona.llm_relevance_filter
            ),
            "is_public": _pick(is_public, persona.is_public),
            "llm_filter_extraction": _pick(
                llm_filter_extraction, persona.llm_filter_extraction
            ),
            "recency_bias": _pick(recency_bias, persona.recency_bias),
            "prompt_ids": _pick(prompt_ids, persona.prompt_ids),
            "document_set_ids": _pick(document_set_ids, persona.document_set_ids),
            "tool_ids": _pick(tool_ids, persona.tool_ids),
            "llm_model_provider_override": _pick(
                llm_model_provider_override, persona.llm_model_provider_override
            ),
            "llm_model_version_override": _pick(
                llm_model_version_override, persona.llm_model_version_override
            ),
            "users": _pick(users, persona.users),
            "groups": _pick(groups, persona.groups),
        }

        response = requests.patch(
            f"{API_SERVER_URL}/persona/{persona.id}",
            json=persona_update_request,
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        updated_persona_data = response.json()

        # NOTE(review): "prompts", "document_sets" and "tools" are stored as
        # returned by the server; if those entries are nested objects rather
        # than plain ids (as "users" appears to be), they need an id
        # extraction here - confirm against the response schema.
        return TestPersona(
            id=updated_persona_data["id"],
            name=updated_persona_data["name"],
            description=updated_persona_data["description"],
            num_chunks=updated_persona_data["num_chunks"],
            llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
            is_public=updated_persona_data["is_public"],
            llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
            recency_bias=updated_persona_data["recency_bias"],
            prompt_ids=updated_persona_data["prompts"],
            document_set_ids=updated_persona_data["document_sets"],
            tool_ids=updated_persona_data["tools"],
            llm_model_provider_override=updated_persona_data[
                "llm_model_provider_override"
            ],
            llm_model_version_override=updated_persona_data[
                "llm_model_version_override"
            ],
            users=[user["email"] for user in updated_persona_data["users"]],
            groups=updated_persona_data["groups"],
        )

    @staticmethod
    def get_all(
        user_performing_action: TestUser | None = None,
    ) -> list[PersonaSnapshot]:
        """Return a snapshot for every persona listed by ``GET /admin/persona``."""
        response = requests.get(
            f"{API_SERVER_URL}/admin/persona",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        return [PersonaSnapshot(**persona) for persona in response.json()]

    @staticmethod
    def verify(
        test_persona: TestPersona,
        user_performing_action: TestUser | None = None,
    ) -> bool:
        """Report whether the server's persona matches ``test_persona``.

        Returns ``False`` when no persona with the same id is listed;
        otherwise returns whether all compared fields are equal. Unlike the
        other managers' ``verify`` helpers, this one returns a bool instead of
        raising.
        """
        all_personas = PersonaManager.get_all(user_performing_action)
        for persona in all_personas:
            if persona.id != test_persona.id:
                continue
            # NOTE(review): prompts / document_sets / tools of the snapshot are
            # compared directly with id lists; if the snapshot holds nested
            # objects there, these comparisons need their ids - confirm
            # against PersonaSnapshot.
            return (
                persona.name == test_persona.name
                and persona.description == test_persona.description
                and persona.num_chunks == test_persona.num_chunks
                and persona.llm_relevance_filter == test_persona.llm_relevance_filter
                and persona.is_public == test_persona.is_public
                and persona.llm_filter_extraction
                == test_persona.llm_filter_extraction
                and persona.llm_model_provider_override
                == test_persona.llm_model_provider_override
                and persona.llm_model_version_override
                == test_persona.llm_model_version_override
                and set(persona.prompts) == set(test_persona.prompt_ids)
                and set(persona.document_sets) == set(test_persona.document_set_ids)
                and set(persona.tools) == set(test_persona.tool_ids)
                and set(user.email for user in persona.users)
                == set(test_persona.users)
                and set(persona.groups) == set(test_persona.groups)
            )
        return False

    @staticmethod
    def delete(
        persona: TestPersona,
        user_performing_action: TestUser | None = None,
    ) -> bool:
        """Delete ``persona`` and return ``response.ok``.

        No exception is raised for an error status; the caller has to check
        the returned flag.
        """
        response = requests.delete(
            f"{API_SERVER_URL}/persona/{persona.id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        return response.ok

View File

@ -0,0 +1,122 @@
from copy import deepcopy
from urllib.parse import urlencode
from uuid import uuid4
import requests
from danswer.db.models import UserRole
from danswer.server.manage.models import AllUsersResponse
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import TestUser
class UserManager:
    """Test helpers for registering users, logging in and checking roles."""

    @staticmethod
    def create(
        name: str | None = None,
    ) -> TestUser:
        """Register a user and log them in.

        The email is ``<name>@test.com`` (with a random ``test<uuid>`` name
        when none is given) and the password is ``"test"``. The returned
        ``TestUser`` owns a private copy of ``GENERAL_HEADERS`` extended with
        the login cookie.
        """
        if name is None:
            name = f"test{uuid4()}"

        email = f"{name}@test.com"
        password = "test"
        registration_payload = {
            "email": email,
            "username": email,
            "password": password,
        }
        registration = requests.post(
            url=f"{API_SERVER_URL}/auth/register",
            json=registration_payload,
            headers=GENERAL_HEADERS,
        )
        registration.raise_for_status()

        new_user = TestUser(
            id=registration.json()["id"],
            email=email,
            password=password,
            headers=deepcopy(GENERAL_HEADERS),
        )
        print(f"Created user {new_user.email}")

        session_cookie = UserManager.login_as_user(new_user)
        new_user.headers["Cookie"] = session_cookie
        return new_user

    @staticmethod
    def login_as_user(test_user: TestUser) -> str:
        """Log ``test_user`` in and return the cookie as ``name=value``.

        The credentials are posted form-encoded to ``/auth/login`` using a
        copy of the user's headers. The first cookie of the response is
        returned; ``Exception`` is raised when the response sets none.
        """
        form_body = urlencode(
            {
                "username": test_user.email,
                "password": test_user.password,
            }
        )
        login_headers = test_user.headers.copy()
        login_headers["Content-Type"] = "application/x-www-form-urlencoded"

        login = requests.post(
            url=f"{API_SERVER_URL}/auth/login",
            data=form_body,
            headers=login_headers,
        )
        login.raise_for_status()

        first_cookie = next(iter(login.cookies), None)
        if not first_cookie:
            raise Exception("Failed to login")

        print(f"Logged in as {test_user.email}")
        return f"{first_cookie.name}={first_cookie.value}"

    @staticmethod
    def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool:
        """Return whether ``/me`` reports ``target_role`` for the user."""
        me = requests.get(
            url=f"{API_SERVER_URL}/me",
            headers=user_to_verify.headers,
        )
        me.raise_for_status()
        reported_role = UserRole(me.json().get("role", ""))
        return target_role == reported_role

    @staticmethod
    def set_role(
        user_to_set: TestUser,
        target_role: UserRole,
        user_to_perform_action: TestUser | None = None,
    ) -> None:
        """Set the role of ``user_to_set`` to ``target_role``.

        The request is made as ``user_to_perform_action``; when that is
        ``None`` the user being changed makes the request themselves.
        """
        if user_to_perform_action is None:
            user_to_perform_action = user_to_set

        role_change = {
            "user_email": user_to_set.email,
            "new_role": target_role.value,
        }
        update = requests.patch(
            url=f"{API_SERVER_URL}/manage/set-user-role",
            json=role_change,
            headers=user_to_perform_action.headers,
        )
        update.raise_for_status()

    @staticmethod
    def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None:
        """Check that ``user`` is among the accepted users of ``/manage/users``.

        The request is made as ``user_to_perform_action`` (defaulting to
        ``user``). Raises ``ValueError`` when no accepted user has both the
        same email and the same id.
        """
        if user_to_perform_action is None:
            user_to_perform_action = user

        if user_to_perform_action:
            request_headers = user_to_perform_action.headers
        else:
            request_headers = GENERAL_HEADERS
        listing = requests.get(
            url=f"{API_SERVER_URL}/manage/users",
            headers=request_headers,
        )
        listing.raise_for_status()
        payload = listing.json()

        accepted_snapshots = [FullUserSnapshot(**entry) for entry in payload["accepted"]]
        invited_snapshots = [
            InvitedUserSnapshot(**entry) for entry in payload["invited"]
        ]
        all_users = AllUsersResponse(
            accepted=accepted_snapshots,
            invited=invited_snapshots,
            accepted_pages=payload["accepted_pages"],
            invited_pages=payload["invited_pages"],
        )

        for accepted_user in all_users.accepted:
            if accepted_user.email == user.email and accepted_user.id == user.id:
                return
        raise ValueError(f"User {user.email} not found")

View File

@ -0,0 +1,148 @@
import time
from uuid import uuid4
import requests
from ee.danswer.server.user_group.models import UserGroup
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import TestUserGroup
class UserGroupManager:
@staticmethod
def create(
    name: str | None = None,
    user_ids: list[str] | None = None,
    cc_pair_ids: list[int] | None = None,
    user_performing_action: TestUser | None = None,
) -> TestUserGroup:
    """Create a user group via ``POST /manage/admin/user-group``.

    A given ``name`` gets the suffix ``-user-group``; without one a random
    ``test-user-group-<uuid>`` name is generated. The returned model is built
    entirely from the response: id, name, the ``id`` of each entry in
    ``users`` and the ``id`` of each entry in ``cc_pairs``.
    """
    if name:
        name = f"{name}-user-group"
    else:
        name = f"test-user-group-{uuid4()}"

    payload = {
        "name": name,
        "user_ids": user_ids or [],
        "cc_pair_ids": cc_pair_ids or [],
    }
    endpoint = f"{API_SERVER_URL}/manage/admin/user-group"
    if user_performing_action:
        request_headers = user_performing_action.headers
    else:
        request_headers = GENERAL_HEADERS

    creation = requests.post(endpoint, json=payload, headers=request_headers)
    creation.raise_for_status()

    created = creation.json()
    member_ids = [member["id"] for member in created["users"]]
    paired_ids = [pair["id"] for pair in created["cc_pairs"]]
    return TestUserGroup(
        id=created["id"],
        name=created["name"],
        user_ids=member_ids,
        cc_pair_ids=paired_ids,
    )
@staticmethod
def edit(
user_group: TestUserGroup,
user_performing_action: TestUser | None = None,
) -> None:
if not user_group.id:
raise ValueError("User group has no ID")
response = requests.patch(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}",
json=user_group.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
@staticmethod
def set_curator_status(
test_user_group: TestUserGroup,
user_to_set_as_curator: TestUser,
is_curator: bool = True,
user_performing_action: TestUser | None = None,
) -> None:
if not user_to_set_as_curator.id:
raise ValueError("User has no ID")
set_curator_request = {
"user_id": user_to_set_as_curator.id,
"is_curator": is_curator,
}
response = requests.post(
f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator",
json=set_curator_request,
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
@staticmethod
def get_all(
user_performing_action: TestUser | None = None,
) -> list[UserGroup]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]
@staticmethod
def verify(
user_group: TestUserGroup,
verify_deleted: bool = False,
user_performing_action: TestUser | None = None,
) -> None:
all_user_groups = UserGroupManager.get_all(user_performing_action)
for fetched_user_group in all_user_groups:
if user_group.id == fetched_user_group.id:
if verify_deleted:
raise ValueError(
f"User group {user_group.id} found but should be deleted"
)
fetched_cc_ids = {cc_pair.id for cc_pair in fetched_user_group.cc_pairs}
fetched_user_ids = {user.id for user in fetched_user_group.users}
user_group_cc_ids = set(user_group.cc_pair_ids)
user_group_user_ids = set(user_group.user_ids)
if (
fetched_cc_ids == user_group_cc_ids
and fetched_user_ids == user_group_user_ids
):
return
if not verify_deleted:
raise ValueError(f"User group {user_group.id} not found")
@staticmethod
def wait_for_sync(
user_groups_to_check: list[TestUserGroup] | None = None,
user_performing_action: TestUser | None = None,
) -> None:
start = time.time()
while True:
user_groups = UserGroupManager.get_all(user_performing_action)
if user_groups_to_check:
check_ids = {user_group.id for user_group in user_groups_to_check}
user_group_ids = {user_group.id for user_group in user_groups}
if not check_ids.issubset(user_group_ids):
raise RuntimeError("Document set not found")
user_groups = [
user_group
for user_group in user_groups
if user_group.id in check_ids
]
if all(ug.is_up_to_date for ug in user_groups):
return
if time.time() - start > MAX_DELAY:
raise TimeoutError(
f"User groups were not synced within the {MAX_DELAY} seconds"
)
else:
print("User groups were not synced yet, waiting...")
time.sleep(2)

View File

@ -20,7 +20,6 @@ from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import IndexingSetting
from danswer.main import setup_postgres
from danswer.main import setup_vespa
from tests.integration.common_utils.llm import seed_default_openai_provider
def _run_migrations(
@ -167,6 +166,4 @@ def reset_all() -> None:
reset_postgres()
print("Resetting Vespa...")
reset_vespa()
print("Seeding LLM Providers...")
seed_default_openai_provider()
print("Finished resetting all.")

View File

@ -1,72 +0,0 @@
import uuid
import requests
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from tests.integration.common_utils.connectors import ConnectorClient
from tests.integration.common_utils.constants import API_SERVER_URL
class SimpleTestDocument(BaseModel):
    """Minimal view of a seeded document: its id and its text content."""

    id: str
    content: str
class SeedDocumentResponse(BaseModel):
    """Result of seeding: the cc-pair used and the documents that were sent."""

    cc_pair_id: int
    documents: list[SimpleTestDocument]
class TestDocumentClient:
    """Helper that seeds test documents through the ingestion API."""

    @staticmethod
    def seed_documents(
        num_docs: int = 5, cc_pair_id: int | None = None
    ) -> SeedDocumentResponse:
        """Ingest ``num_docs`` generated documents and describe what was sent.

        Args:
            num_docs: How many documents to create and post.
            cc_pair_id: Cc-pair to attach the documents to. When falsy, a new
                connector is created and its cc-pair id is used.

        Raises:
            requests.HTTPError: If any ingestion request returns an error
                status; documents posted before the failure stay ingested.
        """
        if not cc_pair_id:
            cc_pair_id = ConnectorClient.create_connector().cc_pair_id

        ingestion_url = f"{API_SERVER_URL}/danswer-api/ingestion"

        # Each payload is posted as soon as it is built, one request per doc.
        payloads: list[dict] = []
        for _ in range(num_docs):
            doc_id = f"test-doc-{uuid.uuid4()}"
            section = {
                "text": f"This is test document {doc_id}",
                "link": f"{doc_id}",
            }
            payload = {
                "document": {
                    "id": doc_id,
                    "sections": [section],
                    "source": DocumentSource.NOT_APPLICABLE,
                    # just for testing metadata
                    "metadata": {"document_id": doc_id},
                    "semantic_identifier": f"Test Document {doc_id}",
                    "from_ingestion_api": True,
                },
                "cc_pair_id": cc_pair_id,
            }
            payloads.append(payload)

            ingest_response = requests.post(ingestion_url, json=payload)
            ingest_response.raise_for_status()

        print("Seeding completed successfully.")

        seeded_docs = [
            SimpleTestDocument(
                id=payload["document"]["id"],
                content=payload["document"]["sections"][0]["text"],
            )
            for payload in payloads
        ]
        return SeedDocumentResponse(cc_pair_id=cc_pair_id, documents=seeded_docs)
if __name__ == "__main__":
seed_documents_resp = TestDocumentClient.seed_documents()

View File

@ -0,0 +1,120 @@
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from danswer.auth.schemas import UserRole
from danswer.search.enums import RecencyBiasSetting
from danswer.server.documents.models import DocumentSource
from danswer.server.documents.models import InputType
"""
These data models are used to represent the data on the testing side of things.
This means the flow is:
1. Make request that changes data in db
2. Make a change to the testing model
3. Retrieve data from db
4. Compare db data with testing model to verify
"""
class TestAPIKey(BaseModel):
    """Test-side record of an API key and the headers used to send it."""

    api_key_id: int
    api_key_display: str
    # Only present on initial creation.
    api_key: str | None = None
    api_key_name: str | None = None
    api_key_role: UserRole

    user_id: UUID
    headers: dict
class TestUser(BaseModel):
    """Test-side record of a user, including the headers for acting as them."""

    id: str
    email: str
    password: str
    # Passed as ``headers=`` on requests performed by this user.
    headers: dict
class TestCredential(BaseModel):
    """Test-side record of a credential and its visibility settings."""

    id: int
    name: str
    credential_json: dict[str, Any]
    admin_public: bool
    source: DocumentSource
    curator_public: bool
    # Ids of the user groups the credential is associated with.
    groups: list[int]
class TestConnector(BaseModel):
    """Test-side record of a connector; group and public flags are optional."""

    id: int
    name: str
    source: DocumentSource
    input_type: InputType
    connector_specific_config: dict[str, Any]
    groups: list[int] | None = None
    is_public: bool | None = None
class SimpleTestDocument(BaseModel):
    """Minimal test-side view of a document: its id and its text content."""

    id: str
    content: str
class TestCCPair(BaseModel):
    """Test-side record of a connector-credential pair."""

    id: int
    name: str
    connector_id: int
    credential_id: int
    is_public: bool
    # Ids of the user groups the pair is associated with.
    groups: list[int]
    # Documents recorded against this pair; starts out empty.
    documents: list[SimpleTestDocument] = Field(default_factory=list)
class TestUserGroup(BaseModel):
    """Test-side record of a user group: its members and attached cc-pairs."""

    id: int
    name: str
    user_ids: list[str]
    cc_pair_ids: list[int]
class TestLLMProvider(BaseModel):
    """Test-side record of an LLM provider configuration."""

    id: int
    name: str
    provider: str
    api_key: str
    default_model_name: str
    is_public: bool
    # Unlike the other models here, groups are full records rather than ids.
    groups: list[TestUserGroup]
    api_base: str | None = None
    api_version: str | None = None
class TestDocumentSet(BaseModel):
    """Test-side record of a document set; list fields default to empty."""

    id: int
    name: str
    description: str
    cc_pair_ids: list[int] = Field(default_factory=list)
    is_public: bool
    is_up_to_date: bool
    users: list[str] = Field(default_factory=list)
    groups: list[int] = Field(default_factory=list)
class TestPersona(BaseModel):
    """Test-side record of a persona and the ids of everything attached to it."""

    id: int
    name: str
    description: str
    num_chunks: float
    llm_relevance_filter: bool
    is_public: bool
    llm_filter_extraction: bool
    recency_bias: RecencyBiasSetting

    # Ids of related objects.
    prompt_ids: list[int]
    document_set_ids: list[int]
    tool_ids: list[int]

    # No defaults: both overrides must be passed explicitly, even as None.
    llm_model_provider_override: str | None
    llm_model_version_override: str | None

    users: list[str]
    groups: list[int]

View File

@ -1,24 +0,0 @@
from typing import cast
import requests
from ee.danswer.server.user_group.models import UserGroup
from ee.danswer.server.user_group.models import UserGroupCreate
from tests.integration.common_utils.constants import API_SERVER_URL
class UserGroupClient:
    """Thin client for the user-group admin endpoint (sends no headers)."""

    @staticmethod
    def create_user_group(user_group_creation_request: UserGroupCreate) -> int:
        """Create a user group and return the id reported by the server.

        Raises:
            requests.HTTPError: If the request returns an error status.
        """
        payload = user_group_creation_request.model_dump()
        create_response = requests.post(
            f"{API_SERVER_URL}/manage/admin/user-group",
            json=payload,
        )
        create_response.raise_for_status()

        created = create_response.json()
        return cast(int, created["id"])

    @staticmethod
    def fetch_user_groups() -> list[UserGroup]:
        """Fetch all user groups and parse each into a ``UserGroup``.

        Raises:
            requests.HTTPError: If the request returns an error status.
        """
        list_response = requests.get(f"{API_SERVER_URL}/manage/admin/user-group")
        list_response.raise_for_status()

        parsed_groups: list[UserGroup] = []
        for raw_group in list_response.json():
            parsed_groups.append(UserGroup(**raw_group))
        return parsed_groups

View File

@ -1,3 +1,4 @@
import os
from collections.abc import Generator
import pytest
@ -9,6 +10,25 @@ from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.vespa import TestVespaClient
def load_env_vars(env_file: str = ".env") -> None:
    """Load ``KEY=VALUE`` pairs from an env file into ``os.environ``.

    The file is looked up relative to the directory containing this module
    (an absolute ``env_file`` path is used as-is). Blank lines and lines
    starting with ``#`` are ignored. Keys and values are stripped of
    surrounding whitespace; only the first ``=`` separates key from value.
    Lines without ``=`` or with an empty key are skipped with a message
    instead of aborting the load. A missing file is reported and otherwise
    ignored.

    Args:
        env_file: Name (or path) of the env file to read.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    env_path = os.path.join(current_dir, env_file)

    try:
        with open(env_path, "r", encoding="utf-8") as f:
            for line_number, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue

                key, separator, value = line.partition("=")
                key = key.strip()
                if not separator or not key:
                    # Report the line number only: the content may be a secret.
                    print(f"Skipping malformed line {line_number} in {env_file}")
                    continue

                os.environ[key] = value.strip()
        print("Successfully loaded environment variables")
    except FileNotFoundError:
        print(f"File {env_file} not found")
# Load environment variables at the module level
load_env_vars()
@pytest.fixture
def db_session() -> Generator[Session, None, None]:
with get_session_context_manager() as session:

View File

@ -1,190 +1,305 @@
import time
"""
This file contains tests for the following:
- Ensuring deletion of a connector also:
- deletes the documents in vespa for that connector
- updates the document sets and user groups to remove the connector
- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
from uuid import uuid4
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common_utils.connectors import ConnectorClient
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.document_sets import DocumentSetClient
from tests.integration.common_utils.seed_documents import TestDocumentClient
from tests.integration.common_utils.user_groups import UserGroupClient
from tests.integration.common_utils.user_groups import UserGroupCreate
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.test_models import TestUserGroup
from tests.integration.common_utils.vespa import TestVespaClient
def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
# create connectors
c1_details = ConnectorClient.create_connector(name_prefix="tc1")
c2_details = ConnectorClient.create_connector(name_prefix="tc2")
c1_seed_res = TestDocumentClient.seed_documents(
num_docs=5, cc_pair_id=c1_details.cc_pair_id
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
c2_seed_res = TestDocumentClient.seed_documents(
num_docs=5, cc_pair_id=c2_details.cc_pair_id
# create connectors
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
cc_pair_2 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
# seed documents
cc_pair_1 = DocumentManager.seed_and_attach_docs(
cc_pair=cc_pair_1,
num_docs=NUM_DOCS,
api_key=api_key,
)
cc_pair_2 = DocumentManager.seed_and_attach_docs(
cc_pair=cc_pair_2,
num_docs=NUM_DOCS,
api_key=api_key,
)
# create document sets
doc_set_1_id = DocumentSetClient.create_document_set(
DocumentSetCreationRequest(
name="Test Document Set 1",
description="Intially connector to be deleted, should be empty after test",
cc_pair_ids=[c1_details.cc_pair_id],
is_public=True,
users=[],
groups=[],
)
doc_set_1 = DocumentSetManager.create(
name="Test Document Set 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
doc_set_2_id = DocumentSetClient.create_document_set(
DocumentSetCreationRequest(
name="Test Document Set 2",
description="Intially both connectors, should contain undeleted connector after test",
cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id],
is_public=True,
users=[],
groups=[],
)
doc_set_2 = DocumentSetManager.create(
name="Test Document Set 2",
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
# wait for document sets to be synced
start = time.time()
while True:
doc_sets = DocumentSetClient.fetch_document_sets()
doc_set_1 = next(
(doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
)
doc_set_2 = next(
(doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
)
if not doc_set_1 or not doc_set_2:
raise RuntimeError("Document set not found")
if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
break
if time.time() - start > MAX_DELAY:
raise TimeoutError("Document sets were not synced within the max delay")
time.sleep(2)
DocumentSetManager.wait_for_sync(user_performing_action=admin_user)
print("Document sets created and synced")
# if so, create ACLs
user_group_1 = UserGroupClient.create_user_group(
UserGroupCreate(
name="Test User Group 1", user_ids=[], cc_pair_ids=[c1_details.cc_pair_id]
)
# create user groups
user_group_1: TestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2 = UserGroupClient.create_user_group(
UserGroupCreate(
name="Test User Group 2",
user_ids=[],
cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id],
)
user_group_2: TestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
# wait for user groups to be available
start = time.time()
while True:
user_groups = {ug.id: ug for ug in UserGroupClient.fetch_user_groups()}
if not (
user_group_1 in user_groups.keys() and user_group_2 in user_groups.keys()
):
raise RuntimeError("User groups not found")
if (
user_groups[user_group_1].is_up_to_date
and user_groups[user_group_2].is_up_to_date
):
break
if time.time() - start > MAX_DELAY:
raise TimeoutError("User groups were not synced within the max delay")
time.sleep(2)
print("User groups created and synced")
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# delete connector 1
ConnectorClient.update_connector_status(
cc_pair_id=c1_details.cc_pair_id, status=ConnectorCredentialPairStatus.PAUSED
CCPairManager.pause_cc_pair(
cc_pair=cc_pair_1,
user_performing_action=admin_user,
)
ConnectorClient.delete_connector(
connector_id=c1_details.connector_id, credential_id=c1_details.credential_id
CCPairManager.delete(
cc_pair=cc_pair_1,
user_performing_action=admin_user,
)
start = time.time()
while True:
connectors = ConnectorClient.get_connectors()
# Update local records to match the database for later comparison
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
doc_set_1.cc_pair_ids = []
doc_set_2.cc_pair_ids = [cc_pair_2.id]
cc_pair_1.groups = []
cc_pair_2.groups = [user_group_2.id]
if c1_details.connector_id not in [c["id"] for c in connectors]:
break
if time.time() - start > MAX_DELAY:
raise TimeoutError("Connector 1 was not deleted within the max delay")
time.sleep(2)
print("Connector 1 deleted")
CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
# validate vespa documents
c1_vespa_docs = vespa_client.get_documents_by_id(
[doc.id for doc in c1_seed_res.documents]
)["documents"]
c2_vespa_docs = vespa_client.get_documents_by_id(
[doc.id for doc in c2_seed_res.documents]
)["documents"]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
doc_set_names=[],
group_names=[],
doc_creating_user=admin_user,
verify_deleted=True,
)
assert len(c1_vespa_docs) == 0
assert len(c2_vespa_docs) == 5
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[doc_set_2.name],
group_names=[user_group_2.name],
doc_creating_user=admin_user,
verify_deleted=False,
)
for doc in c2_vespa_docs:
assert doc["fields"]["access_control_list"] == {
"PUBLIC": 1,
"group:Test User Group 2": 1,
}
assert doc["fields"]["document_sets"] == {"Test Document Set 2": 1}
# check that only connector 1 is deleted
CCPairManager.verify(
cc_pair=cc_pair_2,
user_performing_action=admin_user,
)
# validate document sets
DocumentSetManager.verify(
document_set=doc_set_1,
user_performing_action=admin_user,
)
DocumentSetManager.verify(
document_set=doc_set_2,
user_performing_action=admin_user,
)
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
def test_connector_deletion_for_overlapping_connectors(
reset: None, vespa_client: TestVespaClient
) -> None:
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
"""
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
# create connectors
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
cc_pair_2 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
doc_ids = [str(uuid4())]
cc_pair_1 = DocumentManager.seed_and_attach_docs(
cc_pair=cc_pair_1,
document_ids=doc_ids,
api_key=api_key,
)
cc_pair_2 = DocumentManager.seed_and_attach_docs(
cc_pair=cc_pair_2,
document_ids=doc_ids,
api_key=api_key,
)
# verify vespa document exists and that it is not in any document sets or groups
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
doc_set_names=[],
group_names=[],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[],
group_names=[],
doc_creating_user=admin_user,
)
# create document set
doc_set_1 = DocumentSetManager.create(
name="Test Document Set 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
DocumentSetManager.wait_for_sync(
document_sets_to_check=[doc_set_1],
user_performing_action=admin_user,
)
print("Document set 1 created and synced")
# verify vespa document is in the document set
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
doc_set_names=[doc_set_1.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_creating_user=admin_user,
)
# create a user group and attach it to connector 1
user_group_1: TestUserGroup = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
print("User group 1 created and synced")
# create a user group and attach it to connector 2
user_group_2: TestUserGroup = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
print("User group 2 created and synced")
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# delete connector 1
CCPairManager.pause_cc_pair(
cc_pair=cc_pair_1,
user_performing_action=admin_user,
)
CCPairManager.delete(
cc_pair=cc_pair_1,
user_performing_action=admin_user,
)
# EVERYTHING BELOW HERE IS CURRENTLY BROKEN AND NEEDS TO BE FIXED SERVER SIDE
# wait for deletion to finish
# CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user)
# print("Connector 1 deleted")
# check that only connector 1 is deleted
# TODO: check for the CC pair rather than the connector once the refactor is done
all_connectors = ConnectorClient.get_connectors()
assert len(all_connectors) == 1
assert all_connectors[0]["id"] == c2_details.connector_id
# CCPairManager.verify(
# cc_pair=cc_pair_1,
# verify_deleted=True,
# user_performing_action=admin_user,
# )
# CCPairManager.verify(
# cc_pair=cc_pair_2,
# user_performing_action=admin_user,
# )
# validate document sets
all_doc_sets = DocumentSetClient.fetch_document_sets()
assert len(all_doc_sets) == 2
doc_set_1_found = False
doc_set_2_found = False
for doc_set in all_doc_sets:
if doc_set.id == doc_set_1_id:
doc_set_1_found = True
assert doc_set.cc_pair_descriptors == []
if doc_set.id == doc_set_2_id:
doc_set_2_found = True
assert len(doc_set.cc_pair_descriptors) == 1
assert doc_set.cc_pair_descriptors[0].id == c2_details.cc_pair_id
assert doc_set_1_found
assert doc_set_2_found
# validate user groups
all_user_groups = UserGroupClient.fetch_user_groups()
assert len(all_user_groups) == 2
user_group_1_found = False
user_group_2_found = False
for user_group in all_user_groups:
if user_group.id == user_group_1:
user_group_1_found = True
assert user_group.cc_pairs == []
if user_group.id == user_group_2:
user_group_2_found = True
assert len(user_group.cc_pairs) == 1
assert user_group.cc_pairs[0].id == c2_details.cc_pair_id
assert user_group_1_found
assert user_group_2_found
# verify the document is not in any document sets
# verify the document is only in user group 2
# DocumentManager.verify(
# vespa_client=vespa_client,
# cc_pair=cc_pair_2,
# doc_set_names=[],
# group_names=[user_group_2.name],
# doc_creating_user=admin_user,
# verify_deleted=False,
# )

View File

@ -1,34 +1,59 @@
import requests
from tests.integration.common_utils.connectors import ConnectorClient
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.seed_documents import TestDocumentClient
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestCCPair
from tests.integration.common_utils.test_models import TestUser
def test_send_message_simple_with_history(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
# create connectors
c1_details = ConnectorClient.create_connector(name_prefix="tc1")
c1_seed_res = TestDocumentClient.seed_documents(
num_docs=5, cc_pair_id=c1_details.cc_pair_id
cc_pair_1: TestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
)
api_key: TestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
LLMProviderManager.create(user_performing_action=admin_user)
cc_pair_1 = DocumentManager.seed_and_attach_docs(
cc_pair=cc_pair_1,
num_docs=NUM_DOCS,
api_key=api_key,
)
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [{"message": c1_seed_res.documents[0].content, "role": "user"}],
"messages": [
{
"message": cc_pair_1.documents[0].content,
"role": MessageType.USER.value,
}
],
"persona_id": 0,
"prompt_id": 0,
},
headers=admin_user.headers,
)
assert response.status_code == 200
response_json = response.json()
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == c1_seed_res.documents[0].id
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in c1_seed_res.documents:
for doc in cc_pair_1.documents:
found_doc = next(
(x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
)

View File

@ -1,78 +1,66 @@
import time
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common_utils.document_sets import DocumentSetClient
from tests.integration.common_utils.seed_documents import TestDocumentClient
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestAPIKey
from tests.integration.common_utils.test_models import TestUser
from tests.integration.common_utils.vespa import TestVespaClient
def test_multiple_document_sets_syncing_same_connnector(
reset: None, vespa_client: TestVespaClient
) -> None:
# Seed documents
seed_result = TestDocumentClient.seed_documents(num_docs=5)
cc_pair_id = seed_result.cc_pair_id
# Creating an admin user (first user created is automatically an admin)
admin_user: TestUser = UserManager.create(name="admin_user")
# Create first document set
doc_set_1_id = DocumentSetClient.create_document_set(
DocumentSetCreationRequest(
name="Test Document Set 1",
description="First test document set",
cc_pair_ids=[cc_pair_id],
is_public=True,
users=[],
groups=[],
)
# add api key to user
api_key: TestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
doc_set_2_id = DocumentSetClient.create_document_set(
DocumentSetCreationRequest(
name="Test Document Set 2",
description="Second test document set",
cc_pair_ids=[cc_pair_id],
is_public=True,
users=[],
groups=[],
)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
# wait for syncing to be complete
max_delay = 45
start = time.time()
while True:
doc_sets = DocumentSetClient.fetch_document_sets()
doc_set_1 = next(
(doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None
)
doc_set_2 = next(
(doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None
)
# seed documents
cc_pair_1 = DocumentManager.seed_and_attach_docs(
cc_pair=cc_pair_1,
num_docs=NUM_DOCS,
api_key=api_key,
)
if not doc_set_1 or not doc_set_2:
raise RuntimeError("Document set not found")
# Create document sets
doc_set_1 = DocumentSetManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
doc_set_2 = DocumentSetManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date:
assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [
ccp.id for ccp in doc_set_2.cc_pair_descriptors
]
break
DocumentSetManager.wait_for_sync(
user_performing_action=admin_user,
)
if time.time() - start > max_delay:
raise TimeoutError("Document sets were not synced within the max delay")
time.sleep(2)
# get names so we can compare to what is in vespa
doc_sets = DocumentSetClient.fetch_document_sets()
doc_set_names = {doc_set.name for doc_set in doc_sets}
DocumentSetManager.verify(
document_set=doc_set_1,
user_performing_action=admin_user,
)
DocumentSetManager.verify(
document_set=doc_set_2,
user_performing_action=admin_user,
)
# make sure documents are as expected
seeded_document_ids = [doc.id for doc in seed_result.documents]
result = vespa_client.get_documents_by_id([doc.id for doc in seed_result.documents])
documents = result["documents"]
assert len(documents) == len(seed_result.documents)
assert all(doc["fields"]["document_id"] in seeded_document_ids for doc in documents)
assert all(
set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
doc_set_names=[doc_set_1.name, doc_set_2.name],
doc_creating_user=admin_user,
)

View File

@ -0,0 +1,179 @@
"""
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connector-credential pairs.
"""
import pytest
from requests.exceptions import HTTPError
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_cc_pair_permissions(reset: None) -> None:
    """Check which connector-credential pair operations a curator may perform.

    Builds an admin, a curator, one user group the curator curates and one
    the curator only belongs to, plus a connector and two credentials. It
    then asserts that the forbidden cc pair creations raise ``HTTPError`` and
    that the curator can create, verify, pause and delete a private cc pair
    in the group they curate.

    Args:
        reset: Fixture requested by name; its value is not used here
            (presumably it resets backend state -- confirm in conftest).
    """
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="curated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Setting the user as a curator for the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating another user group that the user is not a curator of
    user_group_2 = UserGroupManager.create(
        name="uncurated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    # Wait on the group that was just created (user_group_2), so the checks
    # below never race against its sync.
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_2], user_performing_action=admin_user
    )

    # Create a connector in the group the curator curates
    connector_1 = ConnectorManager.create(
        name="curator_owned_connector",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=admin_user,
    )
    # currently we dont enforce permissions at the connector level
    # pending cc_pair -> connector rework
    # connector_2 = ConnectorManager.create(
    #     name="curator_visible_connector",
    #     source=DocumentSource.CONFLUENCE,
    #     groups=[user_group_2.id],
    #     is_public=False,
    #     user_performing_action=admin_user,
    # )

    # Create one credential in the curated group and one in the group the
    # curator does not curate
    credential_1 = CredentialManager.create(
        name="curator_owned_credential",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        curator_public=False,
        user_performing_action=admin_user,
    )
    credential_2 = CredentialManager.create(
        name="curator_visible_credential",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_2.id],
        curator_public=False,
        user_performing_action=admin_user,
    )

    # END OF HAPPY PATH

    # --- Tests for things Curators should not be able to do ---

    # Curators should not be able to create a public cc pair
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_1",
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

    # Curators should not be able to create a cc
    # pair for a user group they are not a curator of
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_2",
            groups=[user_group_1.id, user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

    # Curators should not be able to create a cc
    # pair without an attached user group
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_1.id,
            name="invalid_cc_pair_2",
            groups=[],
            is_public=False,
            user_performing_action=curator,
        )

    # # This test is currently disabled because permissions are
    # # not enforced at the connector level
    # # Curators should not be able to create a cc pair
    # # for a user group that the connector does not belong to (NOT WORKING)
    # with pytest.raises(HTTPError):
    #     CCPairManager.create(
    #         connector_id=connector_2.id,
    #         credential_id=credential_1.id,
    #         name="invalid_cc_pair_3",
    #         groups=[user_group_1.id],
    #         is_public=False,
    #         user_performing_action=curator,
    #     )

    # Curators should not be able to create a cc
    # pair for a user group that the credential does not belong to
    with pytest.raises(HTTPError):
        CCPairManager.create(
            connector_id=connector_1.id,
            credential_id=credential_2.id,
            name="invalid_cc_pair_4",
            groups=[user_group_1.id],
            is_public=False,
            user_performing_action=curator,
        )

    # --- Tests for things Curators should be able to do ---

    # Curators should be able to create a private
    # cc pair for a user group they are a curator of
    valid_cc_pair = CCPairManager.create(
        name="valid_cc_pair",
        connector_id=connector_1.id,
        credential_id=credential_1.id,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )

    # Verify the created cc pair
    CCPairManager.verify(
        cc_pair=valid_cc_pair,
        user_performing_action=curator,
    )

    # Test pausing the cc pair
    CCPairManager.pause_cc_pair(valid_cc_pair, user_performing_action=curator)

    # Test deleting the cc pair
    CCPairManager.delete(valid_cc_pair, user_performing_action=curator)
    CCPairManager.wait_for_deletion_completion(user_performing_action=curator)

    CCPairManager.verify(
        cc_pair=valid_cc_pair,
        verify_deleted=True,
        user_performing_action=curator,
    )

View File

@ -0,0 +1,136 @@
"""
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connectors.
"""
import pytest
from requests.exceptions import HTTPError
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_connector_permissions(reset: None) -> None:
    """Check which connector operations a curator may perform.

    Builds an admin, a curator, one user group the curator curates and one
    the curator only belongs to. It then asserts that creating a public
    connector, or one attached to a group the curator does not curate, raises
    ``HTTPError``, and that the curator can create, fetch, list, edit and
    delete a private connector in the group they curate.

    Args:
        reset: Fixture requested by name; its value is not used here
            (presumably it resets backend state -- confirm in conftest).
    """
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="user_group_1",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Setting the user as a curator for the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating another user group that the user is not a curator of
    user_group_2 = UserGroupManager.create(
        name="user_group_2",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    # Wait on the group that was just created (user_group_2), so the checks
    # below never race against its sync.
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_2], user_performing_action=admin_user
    )

    # END OF HAPPY PATH

    # --- Tests for things Curators should not be able to do ---

    # Curators should not be able to create a public connector
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_1",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

    # Curators should not be able to create a connector for a
    # user group they are not a curator of
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_2",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id, user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

    # --- Tests for things Curators should be able to do ---

    # Curators should be able to create a private
    # connector for a user group they are a curator of
    valid_connector = ConnectorManager.create(
        name="valid_connector",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        is_public=False,
        user_performing_action=curator,
    )
    assert valid_connector.id is not None

    # Verify the created connector
    created_connector = ConnectorManager.get(
        valid_connector.id, user_performing_action=curator
    )
    assert created_connector.name == valid_connector.name
    assert created_connector.source == valid_connector.source

    # Verify that the connector can be found in the list of all connectors
    all_connectors = ConnectorManager.get_all(user_performing_action=curator)
    assert any(conn.id == valid_connector.id for conn in all_connectors)

    # Test editing the connector
    valid_connector.name = "updated_valid_connector"
    ConnectorManager.edit(valid_connector, user_performing_action=curator)

    # Verify the edit
    updated_connector = ConnectorManager.get(
        valid_connector.id, user_performing_action=curator
    )
    assert updated_connector.name == "updated_valid_connector"

    # Test deleting the connector
    ConnectorManager.delete(connector=valid_connector, user_performing_action=curator)

    # Verify the deletion
    all_connectors_after_delete = ConnectorManager.get_all(
        user_performing_action=curator
    )
    assert all(conn.id != valid_connector.id for conn in all_connectors_after_delete)

    # Test that curator cannot create a connector for a group they are not a curator of
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_3",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_2.id],
            is_public=False,
            user_performing_action=curator,
        )

    # Test that curator cannot create a public connector
    with pytest.raises(HTTPError):
        ConnectorManager.create(
            name="invalid_connector_4",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id],
            is_public=True,
            user_performing_action=curator,
        )

View File

@ -0,0 +1,108 @@
"""
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating credentials.
"""
import pytest
from requests.exceptions import HTTPError
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_credential_permissions(reset: None) -> None:
    """Check which credential operations a curator may perform.

    Builds an admin, a curator, one user group the curator curates and one
    the curator only belongs to. It then asserts that creating a
    curator-public credential, or one attached to a group the curator does
    not curate, raises ``HTTPError``, and that the curator can create, verify,
    edit and delete a private credential in the group they curate.

    Args:
        reset: Fixture requested by name; its value is not used here
            (presumably it resets backend state -- confirm in conftest).
    """
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a curator
    curator: TestUser = UserManager.create(name="curator")

    # Creating a user group
    user_group_1 = UserGroupManager.create(
        name="user_group_1",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Setting the user as a curator for the user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating another user group that the user is not a curator of
    user_group_2 = UserGroupManager.create(
        name="user_group_2",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    # Wait on the group that was just created (user_group_2), so the checks
    # below never race against its sync.
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_2], user_performing_action=admin_user
    )

    # END OF HAPPY PATH

    # --- Tests for things Curators should not be able to do ---

    # Curators should not be able to create a public credential
    with pytest.raises(HTTPError):
        CredentialManager.create(
            name="invalid_credential_1",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id],
            curator_public=True,
            user_performing_action=curator,
        )

    # Curators should not be able to create a credential for a user group
    # they are not a curator of
    with pytest.raises(HTTPError):
        CredentialManager.create(
            name="invalid_credential_2",
            source=DocumentSource.CONFLUENCE,
            groups=[user_group_1.id, user_group_2.id],
            curator_public=False,
            user_performing_action=curator,
        )

    # --- Tests for things Curators should be able to do ---

    # Curators should be able to create a private credential for a user
    # group they are a curator of
    valid_credential = CredentialManager.create(
        name="valid_credential",
        source=DocumentSource.CONFLUENCE,
        groups=[user_group_1.id],
        curator_public=False,
        user_performing_action=curator,
    )

    # Verify the created credential
    CredentialManager.verify(
        credential=valid_credential,
        user_performing_action=curator,
    )

    # Test editing the credential
    valid_credential.name = "updated_valid_credential"
    CredentialManager.edit(valid_credential, user_performing_action=curator)

    # Verify the edit
    CredentialManager.verify(
        credential=valid_credential,
        user_performing_action=curator,
    )

    # Test deleting the credential
    CredentialManager.delete(valid_credential, user_performing_action=curator)

    # Verify the deletion
    CredentialManager.verify(
        credential=valid_credential,
        verify_deleted=True,
        user_performing_action=curator,
    )

View File

@ -0,0 +1,186 @@
import pytest
from requests.exceptions import HTTPError
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document_set import DocumentSetManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_doc_set_permissions_setup(reset: None) -> None:
    """Check which document set operations curators and admins may perform.

    Builds an admin, a curator, one user group the curator curates and one
    the curator only belongs to, plus a private and a public cc pair created
    by the admin. It then asserts that the invalid document set creations and
    edits raise ``HTTPError``, that a curator can create a document set for
    the group they curate, and that the private cc pair can be added to the
    document set only after the admin attaches it to that group.

    Args:
        reset: Fixture requested by name; its value is not used here
            (presumably it resets backend state -- confirm in conftest).
    """
    # Creating an admin user (first user created is automatically an admin)
    admin_user: TestUser = UserManager.create(name="admin_user")

    # Creating a second user (curator)
    curator: TestUser = UserManager.create(name="curator")

    # Creating the first user group
    user_group_1 = UserGroupManager.create(
        name="curated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )

    # Setting the curator as a curator for the first user group
    UserGroupManager.set_curator_status(
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
        user_performing_action=admin_user,
    )

    # Creating a second user group
    user_group_2 = UserGroupManager.create(
        name="uncurated_user_group",
        user_ids=[curator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    # Wait on the group that was just created (user_group_2), so the checks
    # below never race against its sync.
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_2], user_performing_action=admin_user
    )

    # Admin creates a private cc_pair
    private_cc_pair = CCPairManager.create_from_scratch(
        is_public=False,
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    # Admin creates a public cc_pair
    public_cc_pair = CCPairManager.create_from_scratch(
        is_public=True,
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    # END OF HAPPY PATH

    # --- Tests for things Curators/Admins should not be able to do ---

    # Test that curator cannot create a document set for the group they don't curate
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 1",
            groups=[user_group_2.id],
            cc_pair_ids=[public_cc_pair.id],
            user_performing_action=curator,
        )

    # Test that curator cannot create a document set attached to both groups
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 2",
            is_public=False,
            cc_pair_ids=[public_cc_pair.id],
            groups=[user_group_1.id, user_group_2.id],
            user_performing_action=curator,
        )

    # Test that curator cannot create a document set with no groups
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 3",
            is_public=False,
            cc_pair_ids=[public_cc_pair.id],
            groups=[],
            user_performing_action=curator,
        )

    # Test that curator cannot create a document set with no cc_pairs
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 4",
            is_public=False,
            cc_pair_ids=[],
            groups=[user_group_1.id],
            user_performing_action=curator,
        )

    # Test that admin cannot create a document set with no cc_pairs
    with pytest.raises(HTTPError):
        DocumentSetManager.create(
            name="Invalid Document Set 4",
            is_public=False,
            cc_pair_ids=[],
            groups=[user_group_1.id],
            user_performing_action=admin_user,
        )

    # --- Tests for things Curators should be able to do ---

    # Test that curator can create a document set for the group they curate
    valid_doc_set = DocumentSetManager.create(
        name="Valid Document Set",
        is_public=False,
        cc_pair_ids=[public_cc_pair.id],
        groups=[user_group_1.id],
        user_performing_action=curator,
    )

    # Verify that the valid document set was created
    DocumentSetManager.verify(
        document_set=valid_doc_set,
        user_performing_action=admin_user,
    )

    # Verify that only one document set exists
    all_doc_sets = DocumentSetManager.get_all(user_performing_action=admin_user)
    assert len(all_doc_sets) == 1

    # Add the private_cc_pair to the doc set on our end for later comparison
    valid_doc_set.cc_pair_ids.append(private_cc_pair.id)

    # Confirm the curator can't add the private_cc_pair to the doc set
    with pytest.raises(HTTPError):
        DocumentSetManager.edit(
            document_set=valid_doc_set,
            user_performing_action=curator,
        )

    # Confirm the admin can't add the private_cc_pair to the doc set
    with pytest.raises(HTTPError):
        DocumentSetManager.edit(
            document_set=valid_doc_set,
            user_performing_action=admin_user,
        )

    # Verify the document set has not been updated in the db: the local copy
    # now lists the private cc pair, so verification is expected to fail
    with pytest.raises(ValueError):
        DocumentSetManager.verify(
            document_set=valid_doc_set,
            user_performing_action=admin_user,
        )

    # Add the private_cc_pair to the user group on our end for later comparison
    user_group_1.cc_pair_ids.append(private_cc_pair.id)

    # Admin adds the cc_pair to the group the curator curates
    UserGroupManager.edit(
        user_group=user_group_1,
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1], user_performing_action=admin_user
    )
    UserGroupManager.verify(
        user_group=user_group_1,
        user_performing_action=admin_user,
    )

    # Confirm the curator can now add the cc_pair to the doc set
    DocumentSetManager.edit(
        document_set=valid_doc_set,
        user_performing_action=curator,
    )
    DocumentSetManager.wait_for_sync(
        document_sets_to_check=[valid_doc_set], user_performing_action=admin_user
    )

    # Verify the updated document set
    DocumentSetManager.verify(
        document_set=valid_doc_set,
        user_performing_action=admin_user,
    )

View File

@ -0,0 +1,93 @@
"""
This file tests the ability of different user types to set the role of other users.
"""
import pytest
from requests.exceptions import HTTPError
from danswer.db.models import UserRole
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_user_role_setting_permissions(reset: None) -> None:
    """Exercise who is allowed to change another user's role.

    Walks through role changes attempted by an admin, a basic user and a
    global curator, then checks that a user can be made curator of a group
    only once they are a member of that group. Rejected operations are
    expected to surface as ``HTTPError``.

    Args:
        reset: Fixture requested by name; its value is not used here
            (presumably it resets backend state -- confirm in conftest).
    """
    # The first user created is automatically an admin
    admin_user: TestUser = UserManager.create(name="admin_user")
    assert UserManager.verify_role(admin_user, UserRole.ADMIN)

    # Every later user starts out with the basic role
    basic_user: TestUser = UserManager.create(name="basic_user")
    assert UserManager.verify_role(basic_user, UserRole.BASIC)

    # The future curator is also basic at first
    curator: TestUser = UserManager.create(name="curator")
    assert UserManager.verify_role(curator, UserRole.BASIC)

    # Promoting to curator without a group membership must be rejected
    with pytest.raises(HTTPError):
        UserManager.set_role(
            user_to_perform_action=admin_user,
            user_to_set=curator,
            target_role=UserRole.CURATOR,
        )

    global_curator: TestUser = UserManager.create(name="global_curator")
    assert UserManager.verify_role(global_curator, UserRole.BASIC)

    # A basic user may not hand out the global curator role
    with pytest.raises(HTTPError):
        UserManager.set_role(
            user_to_perform_action=basic_user,
            user_to_set=global_curator,
            target_role=UserRole.GLOBAL_CURATOR,
        )

    # An admin may hand out the global curator role
    UserManager.set_role(
        user_to_perform_action=admin_user,
        user_to_set=global_curator,
        target_role=UserRole.GLOBAL_CURATOR,
    )
    assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR)

    # A global curator may not change their own role back to basic
    with pytest.raises(HTTPError):
        UserManager.set_role(
            user_to_perform_action=global_curator,
            user_to_set=global_curator,
            target_role=UserRole.BASIC,
        )
    assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR)

    # Start with an empty user group
    user_group_1 = UserGroupManager.create(
        user_performing_action=admin_user,
        name="user_group_1",
        user_ids=[],
        cc_pair_ids=[],
    )
    UserGroupManager.wait_for_sync(
        user_performing_action=admin_user,
        user_groups_to_check=[user_group_1],
    )

    # Must fail: the curator is not yet a member of the group
    with pytest.raises(HTTPError):
        UserGroupManager.set_curator_status(
            user_performing_action=admin_user,
            test_user_group=user_group_1,
            user_to_set_as_curator=curator,
        )

    # Put the curator into the group and push the change
    user_group_1.user_ids = [curator.id]
    UserGroupManager.edit(user_performing_action=admin_user, user_group=user_group_1)
    UserGroupManager.wait_for_sync(
        user_performing_action=admin_user,
        user_groups_to_check=[user_group_1],
    )

    # Must succeed now that the curator is a member of the group
    UserGroupManager.set_curator_status(
        user_performing_action=admin_user,
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
    )

View File

@ -0,0 +1,86 @@
"""
This test tests the happy path for curator permissions
"""
from danswer.db.models import UserRole
from danswer.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import TestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def test_whole_curator_flow(reset: None) -> None:
    """Run the curator happy path from group setup to cc pair deletion.

    An admin creates a group containing the curator and grants curator
    status. The curator then creates a credential, a connector (which is
    also edited) and a cc pair, pauses the cc pair and deletes it.

    Args:
        reset: Fixture requested by name; its value is not used here
            (presumably it resets backend state -- confirm in conftest).
    """
    # The first user created is automatically an admin
    admin_user: TestUser = UserManager.create(name="admin_user")
    assert UserManager.verify_role(admin_user, UserRole.ADMIN)

    # Second user, who will become the curator
    curator: TestUser = UserManager.create(name="curator")

    # Group that contains the curator
    user_group_1 = UserGroupManager.create(
        user_performing_action=admin_user,
        name="user_group_1",
        user_ids=[curator.id],
        cc_pair_ids=[],
    )
    UserGroupManager.wait_for_sync(
        user_performing_action=admin_user,
        user_groups_to_check=[user_group_1],
    )

    # Admin grants curator status over user_group_1
    UserGroupManager.set_curator_status(
        user_performing_action=admin_user,
        test_user_group=user_group_1,
        user_to_set_as_curator=curator,
    )
    assert UserManager.verify_role(curator, UserRole.CURATOR)

    # Curator creates a private credential in their group
    test_credential = CredentialManager.create(
        user_performing_action=curator,
        name="curator_test_credential",
        source=DocumentSource.FILE,
        groups=[user_group_1.id],
        curator_public=False,
    )

    # Curator creates a private connector in their group
    test_connector = ConnectorManager.create(
        user_performing_action=curator,
        name="curator_test_connector",
        source=DocumentSource.FILE,
        groups=[user_group_1.id],
        is_public=False,
    )

    # Curator renames the connector
    test_connector.name = "updated_test_connector"
    ConnectorManager.edit(user_performing_action=curator, connector=test_connector)

    # Curator pairs the connector with the credential
    test_cc_pair = CCPairManager.create(
        user_performing_action=curator,
        name="curator_test_cc_pair",
        connector_id=test_connector.id,
        credential_id=test_credential.id,
        groups=[user_group_1.id],
        is_public=False,
    )
    CCPairManager.verify(user_performing_action=admin_user, cc_pair=test_cc_pair)

    # Curator can pause the cc pair
    CCPairManager.pause_cc_pair(user_performing_action=curator, cc_pair=test_cc_pair)

    # Curator can delete the cc pair
    CCPairManager.delete(user_performing_action=curator, cc_pair=test_cc_pair)
    CCPairManager.wait_for_deletion_completion(user_performing_action=curator)

    # Admin confirms the cc pair is gone
    CCPairManager.verify(
        user_performing_action=admin_user,
        cc_pair=test_cc_pair,
        verify_deleted=True,
    )

View File

@ -73,7 +73,7 @@ export function updateCredential(credentialId: number, newDetails: any) {
([key, value]) => key !== "name" && value !== ""
)
);
return fetch(`/api/manage/admin/credentials/${credentialId}`, {
return fetch(`/api/manage/admin/credential/${credentialId}`, {
method: "PUT",
headers: {
"Content-Type": "application/json",
@ -86,7 +86,7 @@ export function updateCredential(credentialId: number, newDetails: any) {
}
export function swapCredential(newCredentialId: number, connectorId: number) {
return fetch(`/api/manage/admin/credentials/swap`, {
return fetch(`/api/manage/admin/credential/swap`, {
method: "PUT",
headers: {
"Content-Type": "application/json",