diff --git a/backend/ee/onyx/external_permissions/confluence/doc_sync.py b/backend/ee/onyx/external_permissions/confluence/doc_sync.py index bd78a8ead..9805cdad6 100644 --- a/backend/ee/onyx/external_permissions/confluence/doc_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/doc_sync.py @@ -24,7 +24,9 @@ _REQUEST_PAGINATION_LIMIT = 5000 def _get_server_space_permissions( confluence_client: OnyxConfluence, space_key: str ) -> ExternalAccess: - space_permissions = confluence_client.get_space_permissions(space_key=space_key) + space_permissions = confluence_client.get_all_space_permissions_server( + space_key=space_key + ) viewspace_permissions = [] for permission_category in space_permissions: diff --git a/backend/onyx/background/celery/celery_utils.py b/backend/onyx/background/celery/celery_utils.py index fc6fef1fa..394dff352 100644 --- a/backend/onyx/background/celery/celery_utils.py +++ b/backend/onyx/background/celery/celery_utils.py @@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SlimConnector from onyx.connectors.models import Document from onyx.db.connector_credential_pair import get_connector_credential_pair +from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import TaskStatus from onyx.db.models import TaskQueueState from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface @@ -41,14 +42,21 @@ def _get_deletion_status( return None redis_connector = RedisConnector(tenant_id, cc_pair.id) - if not redis_connector.delete.fenced: - return None + if redis_connector.delete.fenced: + return TaskQueueState( + task_id="", + task_name=redis_connector.delete.fence_key, + status=TaskStatus.STARTED, + ) - return TaskQueueState( - task_id="", - task_name=redis_connector.delete.fence_key, - status=TaskStatus.STARTED, - ) + if cc_pair.status == ConnectorCredentialPairStatus.DELETING: + return TaskQueueState( + task_id="", + 
task_name=redis_connector.delete.fence_key, + status=TaskStatus.PENDING, + ) + + return None def get_deletion_attempt_snapshot( diff --git a/backend/onyx/connectors/confluence/onyx_confluence.py b/backend/onyx/connectors/confluence/onyx_confluence.py index d95fa1963..e6a2b957e 100644 --- a/backend/onyx/connectors/confluence/onyx_confluence.py +++ b/backend/onyx/connectors/confluence/onyx_confluence.py @@ -354,6 +354,33 @@ class OnyxConfluence(Confluence): group_name = quote(group_name) yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit) + def get_all_space_permissions_server( + self, + space_key: str, + ) -> list[dict[str, Any]]: + """ + This is a confluence server specific method that can be used to + fetch the permissions of a space. + This allows better logging than calling the get_space_permissions method + because it returns a jsonrpc response. + """ + url = "rpc/json-rpc/confluenceservice-v2" + data = { + "jsonrpc": "2.0", + "method": "getSpacePermissionSets", + "id": 7, + "params": [space_key], + } + response = self.post(url, data=data) + logger.debug(f"jsonrpc response: {response}") + if not response.get("result"): + logger.warning( + f"No jsonrpc response for space permissions for space {space_key}" + f"\nResponse: {response}" + ) + + return response.get("result", []) + def _validate_connector_configuration( credentials: dict[str, Any], diff --git a/backend/onyx/connectors/google_utils/google_kv.py b/backend/onyx/connectors/google_utils/google_kv.py index 96785e325..4f714e98f 100644 --- a/backend/onyx/connectors/google_utils/google_kv.py +++ b/backend/onyx/connectors/google_utils/google_kv.py @@ -17,6 +17,9 @@ from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from onyx.connectors.google_utils.resources import get_drive_service from onyx.connectors.google_utils.resources import get_gmail_service +from onyx.connectors.google_utils.shared_constants import ( + 
DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -29,6 +32,9 @@ from onyx.connectors.google_utils.shared_constants import ( from onyx.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) +from onyx.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from onyx.connectors.google_utils.shared_constants import ( MISSING_SCOPES_ERROR_STR, ) @@ -96,6 +102,7 @@ def update_credential_access_tokens( user: User, db_session: Session, source: DocumentSource, + auth_method: GoogleOAuthAuthenticationMethod, ) -> OAuthCredentials | None: app_credentials = get_google_app_cred(source) flow = InstalledAppFlow.from_client_config( @@ -119,6 +126,7 @@ def update_credential_access_tokens( new_creds_dict = { DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, + DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value, } if not update_credential_json(credential_id, new_creds_dict, user, db_session): @@ -129,6 +137,7 @@ def update_credential_access_tokens( def build_service_account_creds( source: DocumentSource, primary_admin_email: str | None = None, + name: str | None = None, ) -> CredentialBase: service_account_key = get_service_account_key(source=source) @@ -138,10 +147,15 @@ def build_service_account_creds( if primary_admin_email: credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email + credential_dict[ + DB_CREDENTIALS_AUTHENTICATION_METHOD + ] = GoogleOAuthAuthenticationMethod.UPLOADED.value + return CredentialBase( credential_json=credential_dict, admin_public=True, source=source, + name=name, ) diff --git a/backend/onyx/server/documents/cc_pair.py b/backend/onyx/server/documents/cc_pair.py index cf8746953..64086de5d 100644 --- a/backend/onyx/server/documents/cc_pair.py +++ b/backend/onyx/server/documents/cc_pair.py @@ -164,17 +164,12 @@ def update_cc_pair_status( db_session: Session = 
Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> JSONResponse: - """This method may wait up to 30 seconds if pausing the connector due to the need to - terminate tasks in progress. Tasks are not guaranteed to terminate within the - timeout. + """This method returns nearly immediately. It simply sets some signals and + optimistically assumes any running background processes will clean themselves up. + This is done to improve the perceived end user experience. Returns HTTPStatus.OK if everything finished. - Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks - did not finish within the timeout. """ - WAIT_TIMEOUT = 15.0 - still_terminating = False - cc_pair = get_connector_credential_pair_from_id( cc_pair_id=cc_pair_id, db_session=db_session, @@ -188,73 +183,37 @@ def update_cc_pair_status( detail="Connection not found for current user's permissions", ) + redis_connector = RedisConnector(tenant_id, cc_pair_id) if status_update_request.status == ConnectorCredentialPairStatus.PAUSED: + redis_connector.stop.set_fence(True) + search_settings_list: list[SearchSettings] = get_active_search_settings( db_session ) - redis_connector = RedisConnector(tenant_id, cc_pair_id) + while True: + for search_settings in search_settings_list: + redis_connector_index = redis_connector.new_index(search_settings.id) + if not redis_connector_index.fenced: + continue - try: - redis_connector.stop.set_fence(True) - while True: - logger.debug( - f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}" - ) - wait_succeeded = redis_connector.wait_for_indexing_termination( - search_settings_list, WAIT_TIMEOUT - ) - if wait_succeeded: - logger.debug( - f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}" - ) - break + index_payload = redis_connector_index.payload + if not index_payload: + continue - logger.debug( - "Wait for indexing soft termination timed out. 
" - f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}" - ) + if not index_payload.celery_task_id: + continue - for search_settings in search_settings_list: - redis_connector_index = redis_connector.new_index( - search_settings.id - ) - if not redis_connector_index.fenced: - continue + # Revoke the task to prevent it from running + primary_app.control.revoke(index_payload.celery_task_id) - index_payload = redis_connector_index.payload - if not index_payload: - continue + # If it is running, then signaling for termination will get the + # watchdog thread to kill the spawned task + redis_connector_index.set_terminate(index_payload.celery_task_id) - if not index_payload.celery_task_id: - continue - - # Revoke the task to prevent it from running - primary_app.control.revoke(index_payload.celery_task_id) - - # If it is running, then signaling for termination will get the - # watchdog thread to kill the spawned task - redis_connector_index.set_terminate(index_payload.celery_task_id) - - logger.debug( - f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}" - ) - wait_succeeded = redis_connector.wait_for_indexing_termination( - search_settings_list, WAIT_TIMEOUT - ) - if wait_succeeded: - logger.debug( - f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}" - ) - break - - logger.debug( - f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}" - ) - still_terminating = True - break - finally: - redis_connector.stop.set_fence(False) + break + else: + redis_connector.stop.set_fence(False) update_connector_credential_pair_from_id( db_session=db_session, @@ -264,14 +223,6 @@ def update_cc_pair_status( db_session.commit() - if still_terminating: - return JSONResponse( - status_code=HTTPStatus.ACCEPTED, - content={ - "message": "Request accepted, background task termination still in progress" - }, - ) - return JSONResponse( status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)} ) diff --git 
a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index 46c74942e..6be024cb2 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -53,8 +53,9 @@ from onyx.connectors.google_utils.google_kv import ( upsert_service_account_key, ) from onyx.connectors.google_utils.google_kv import verify_csrf +from onyx.connectors.google_utils.shared_constants import DB_CREDENTIALS_DICT_TOKEN_KEY from onyx.connectors.google_utils.shared_constants import ( - DB_CREDENTIALS_DICT_TOKEN_KEY, + GoogleOAuthAuthenticationMethod, ) from onyx.db.connector import create_connector from onyx.db.connector import delete_connector @@ -314,6 +315,7 @@ def upsert_service_account_credential( credential_base = build_service_account_creds( DocumentSource.GOOGLE_DRIVE, primary_admin_email=service_account_credential_request.google_primary_admin, + name="Service Account (uploaded)", ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -408,6 +410,38 @@ def upload_files( return FileUploadResponse(file_paths=deduped_file_paths) +@router.get("/admin/connector") +def get_connectors_by_credential( + _: User = Depends(current_curator_or_admin_user), + db_session: Session = Depends(get_session), + credential: int | None = None, +) -> list[ConnectorSnapshot]: + """Get a list of connectors. 
Allow filtering by a specific credential id.""" + + connectors = fetch_connectors(db_session) + + filtered_connectors = [] + for connector in connectors: + if connector.source == DocumentSource.INGESTION_API: + # don't include INGESTION_API, as it's a system level + # connector not manageable by the user + continue + + if credential is not None: + found = False + for cc_pair in connector.credentials: + if credential == cc_pair.credential_id: + found = True + break + + if not found: + continue + + filtered_connectors.append(ConnectorSnapshot.from_connector_db_model(connector)) + + return filtered_connectors + + # Retrieves most recent failure cases for connectors that are currently failing @router.get("/admin/connector/failed-indexing-status") def get_currently_failed_indexing_status( @@ -987,7 +1021,12 @@ def gmail_callback( credential_id = int(credential_id_cookie) verify_csrf(credential_id, callback.state) credentials: Credentials | None = update_credential_access_tokens( - callback.code, credential_id, user, db_session, DocumentSource.GMAIL + callback.code, + credential_id, + user, + db_session, + DocumentSource.GMAIL, + GoogleOAuthAuthenticationMethod.UPLOADED, ) if credentials is None: raise HTTPException( @@ -1013,7 +1052,12 @@ def google_drive_callback( verify_csrf(credential_id, callback.state) credentials: Credentials | None = update_credential_access_tokens( - callback.code, credential_id, user, db_session, DocumentSource.GOOGLE_DRIVE + callback.code, + credential_id, + user, + db_session, + DocumentSource.GOOGLE_DRIVE, + GoogleOAuthAuthenticationMethod.UPLOADED, ) if credentials is None: raise HTTPException( diff --git a/backend/onyx/server/documents/credential.py b/backend/onyx/server/documents/credential.py index 51d9643dc..b68ee660c 100644 --- a/backend/onyx/server/documents/credential.py +++ b/backend/onyx/server/documents/credential.py @@ -9,7 +9,6 @@ from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_user 
from onyx.db.credentials import alter_credential from onyx.db.credentials import cleanup_gmail_credentials -from onyx.db.credentials import cleanup_google_drive_credentials from onyx.db.credentials import create_credential from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE from onyx.db.credentials import delete_credential @@ -133,8 +132,6 @@ def create_credential_from_model( # Temporary fix for empty Google App credentials if credential_info.source == DocumentSource.GMAIL: cleanup_gmail_credentials(db_session=db_session) - if credential_info.source == DocumentSource.GOOGLE_DRIVE: - cleanup_google_drive_credentials(db_session=db_session) credential = create_credential(credential_info, user, db_session) return ObjectCreationIdResponse( diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index dc36ce649..b5b52f590 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -142,19 +142,20 @@ def put_llm_provider( detail=f"LLM Provider with name {llm_provider.name} already exists", ) - # Ensure default_model_name and fast_default_model_name are in display_model_names - # This is necessary for custom models and Bedrock/Azure models - if llm_provider.display_model_names is None: - llm_provider.display_model_names = [] + if llm_provider.display_model_names is not None: + # Ensure default_model_name and fast_default_model_name are in display_model_names + # This is necessary for custom models and Bedrock/Azure models + if llm_provider.default_model_name not in llm_provider.display_model_names: + llm_provider.display_model_names.append(llm_provider.default_model_name) - if llm_provider.default_model_name not in llm_provider.display_model_names: - llm_provider.display_model_names.append(llm_provider.default_model_name) - - if ( - llm_provider.fast_default_model_name - and llm_provider.fast_default_model_name not in llm_provider.display_model_names - ): - 
llm_provider.display_model_names.append(llm_provider.fast_default_model_name) + if ( + llm_provider.fast_default_model_name + and llm_provider.fast_default_model_name + not in llm_provider.display_model_names + ): + llm_provider.display_model_names.append( + llm_provider.fast_default_model_name + ) try: return upsert_llm_provider( diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 5eba1e66f..ec50669d0 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -4,8 +4,12 @@ from collections.abc import Generator import pytest from sqlalchemy.orm import Session +from onyx.auth.schemas import UserRole from onyx.db.engine import get_session_context_manager from onyx.db.search_settings import get_current_search_settings +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.managers.user import build_email +from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.reset import reset_all from tests.integration.common_utils.reset import reset_all_multitenant @@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None: return None +@pytest.fixture +def admin_user() -> DATestUser | None: + try: + return UserManager.create(name="admin_user") + except Exception: + pass + + try: + return UserManager.login_as_user( + DATestUser( + id="", + email=build_email("admin_user"), + password=DEFAULT_PASSWORD, + headers=GENERAL_HEADERS, + role=UserRole.ADMIN, + is_active=True, + ) + ) + except Exception: + pass + + return None + + @pytest.fixture def reset_multitenant() -> None: reset_all_multitenant() diff --git a/backend/tests/integration/openai_assistants_api/conftest.py b/backend/tests/integration/openai_assistants_api/conftest.py index 37ada5cd8..5fc6660ee 100644 --- 
a/backend/tests/integration/openai_assistants_api/conftest.py +++ b/backend/tests/integration/openai_assistants_api/conftest.py @@ -7,40 +7,12 @@ import requests from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.llm_provider import LLMProviderManager -from tests.integration.common_utils.managers.user import build_email -from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD -from tests.integration.common_utils.managers.user import UserManager -from tests.integration.common_utils.managers.user import UserRole from tests.integration.common_utils.test_models import DATestLLMProvider from tests.integration.common_utils.test_models import DATestUser BASE_URL = f"{API_SERVER_URL}/openai-assistants" -@pytest.fixture -def admin_user() -> DATestUser | None: - try: - return UserManager.create("admin_user") - except Exception: - pass - - try: - return UserManager.login_as_user( - DATestUser( - id="", - email=build_email("admin_user"), - password=DEFAULT_PASSWORD, - headers=GENERAL_HEADERS, - role=UserRole.ADMIN, - is_active=True, - ) - ) - except Exception: - pass - - return None - - @pytest.fixture def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider: return LLMProviderManager.create(user_performing_action=admin_user) diff --git a/backend/tests/integration/tests/llm_provider/test_llm_provider.py b/backend/tests/integration/tests/llm_provider/test_llm_provider.py new file mode 100644 index 000000000..4540f24b2 --- /dev/null +++ b/backend/tests/integration/tests/llm_provider/test_llm_provider.py @@ -0,0 +1,120 @@ +import uuid + +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.test_models import DATestUser + + +_DEFAULT_MODELS = ["gpt-4", "gpt-4o"] + + +def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None: + 
"""Utility function to fetch an LLM provider by ID""" + response = requests.get( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + ) + assert response.status_code == 200 + providers = response.json() + return next((p for p in providers if p["id"] == provider_id), None) + + +def test_create_llm_provider_without_display_model_names( + admin_user: DATestUser, +) -> None: + """Test creating an LLM provider without specifying + display_model_names and verify it's null in response""" + # Create LLM provider without display_model_names + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "name": str(uuid.uuid4()), + "provider": "openai", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + + # Verify display_model_names is None/null + assert provider_data is not None + assert provider_data["model_names"] == _DEFAULT_MODELS + assert provider_data["default_model_name"] == _DEFAULT_MODELS[0] + assert provider_data["display_model_names"] is None + + +def test_update_llm_provider_model_names(admin_user: DATestUser) -> None: + """Test updating an LLM provider's model_names""" + # First create provider with a single model name + name = str(uuid.uuid4()) + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "name": name, + "provider": "openai", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": [_DEFAULT_MODELS[0]], + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + + # Update with the full list of model_names + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "id": created_provider["id"], + "name": name, + "provider": 
created_provider["provider"], + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + + # Verify update + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is not None + assert provider_data["model_names"] == _DEFAULT_MODELS + + +def test_delete_llm_provider(admin_user: DATestUser) -> None: + """Test deleting an LLM provider""" + # Create a provider + response = requests.put( + f"{API_SERVER_URL}/admin/llm/provider", + headers=admin_user.headers, + json={ + "name": "test-provider-delete", + "provider": "openai", + "default_model_name": _DEFAULT_MODELS[0], + "model_names": _DEFAULT_MODELS, + "is_public": True, + "groups": [], + }, + ) + assert response.status_code == 200 + created_provider = response.json() + + # Delete the provider + response = requests.delete( + f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}", + headers=admin_user.headers, + ) + assert response.status_code == 200 + + # Verify provider is deleted by checking it's not in the list + provider_data = _get_provider_by_id(admin_user, created_provider["id"]) + assert provider_data is None diff --git a/web/src/app/admin/connector/[ccPairId]/DeletionButton.tsx b/web/src/app/admin/connector/[ccPairId]/DeletionButton.tsx index ccef14b5a..fe430af33 100644 --- a/web/src/app/admin/connector/[ccPairId]/DeletionButton.tsx +++ b/web/src/app/admin/connector/[ccPairId]/DeletionButton.tsx @@ -8,7 +8,13 @@ import { deleteCCPair } from "@/lib/documentDeletion"; import { mutate } from "swr"; import { buildCCPairInfoUrl } from "./lib"; -export function DeletionButton({ ccPair }: { ccPair: CCPairFullInfo }) { +export function DeletionButton({ + ccPair, + refresh, +}: { + ccPair: CCPairFullInfo; + refresh: () => void; +}) { const { popup, setPopup } = usePopup(); const isDeleting = @@ -31,14 +37,22 @@ export function DeletionButton({ ccPair }: { ccPair: 
CCPairFullInfo }) { {popup}