mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-10 04:49:29 +02:00
Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/light_cpu
This commit is contained in:
commit
8e25c3c412
@ -24,7 +24,9 @@ _REQUEST_PAGINATION_LIMIT = 5000
|
||||
def _get_server_space_permissions(
|
||||
confluence_client: OnyxConfluence, space_key: str
|
||||
) -> ExternalAccess:
|
||||
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
|
||||
space_permissions = confluence_client.get_all_space_permissions_server(
|
||||
space_key=space_key
|
||||
)
|
||||
|
||||
viewspace_permissions = []
|
||||
for permission_category in space_permissions:
|
||||
|
@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@ -41,14 +42,21 @@ def _get_deletion_status(
|
||||
return None
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair.id)
|
||||
if not redis_connector.delete.fenced:
|
||||
return None
|
||||
if redis_connector.delete.fenced:
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
)
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.STARTED,
|
||||
)
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
return TaskQueueState(
|
||||
task_id="",
|
||||
task_name=redis_connector.delete.fence_key,
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
|
@ -354,6 +354,33 @@ class OnyxConfluence(Confluence):
|
||||
group_name = quote(group_name)
|
||||
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
|
||||
|
||||
def get_all_space_permissions_server(
|
||||
self,
|
||||
space_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
This is a confluence server specific method that can be used to
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "getSpacePermissionSets",
|
||||
"id": 7,
|
||||
"params": [space_key],
|
||||
}
|
||||
response = self.post(url, data=data)
|
||||
logger.debug(f"jsonrpc response: {response}")
|
||||
if not response.get("result"):
|
||||
logger.warning(
|
||||
f"No jsonrpc response for space permissions for space {space_key}"
|
||||
f"\nResponse: {response}"
|
||||
)
|
||||
|
||||
return response.get("result", [])
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
|
@ -17,6 +17,9 @@ from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.google_utils.resources import get_gmail_service
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
@ -29,6 +32,9 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GOOGLE_SCOPES,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
MISSING_SCOPES_ERROR_STR,
|
||||
)
|
||||
@ -96,6 +102,7 @@ def update_credential_access_tokens(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
auth_method: GoogleOAuthAuthenticationMethod,
|
||||
) -> OAuthCredentials | None:
|
||||
app_credentials = get_google_app_cred(source)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
@ -119,6 +126,7 @@ def update_credential_access_tokens(
|
||||
new_creds_dict = {
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value,
|
||||
}
|
||||
|
||||
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
|
||||
@ -129,6 +137,7 @@ def update_credential_access_tokens(
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
primary_admin_email: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key(source=source)
|
||||
|
||||
@ -138,10 +147,15 @@ def build_service_account_creds(
|
||||
if primary_admin_email:
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
|
||||
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.UPLOADED.value
|
||||
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=source,
|
||||
name=name,
|
||||
)
|
||||
|
||||
|
||||
|
@ -164,17 +164,12 @@ def update_cc_pair_status(
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""This method may wait up to 30 seconds if pausing the connector due to the need to
|
||||
terminate tasks in progress. Tasks are not guaranteed to terminate within the
|
||||
timeout.
|
||||
"""This method returns nearly immediately. It simply sets some signals and
|
||||
optimistically assumes any running background processes will clean themselves up.
|
||||
This is done to improve the perceived end user experience.
|
||||
|
||||
Returns HTTPStatus.OK if everything finished.
|
||||
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
|
||||
did not finish within the timeout.
|
||||
"""
|
||||
WAIT_TIMEOUT = 15.0
|
||||
still_terminating = False
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
@ -188,73 +183,37 @@ def update_cc_pair_status(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
|
||||
redis_connector.stop.set_fence(True)
|
||||
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
while True:
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if not redis_connector_index.fenced:
|
||||
continue
|
||||
|
||||
try:
|
||||
redis_connector.stop.set_fence(True)
|
||||
while True:
|
||||
logger.debug(
|
||||
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
|
||||
)
|
||||
wait_succeeded = redis_connector.wait_for_indexing_termination(
|
||||
search_settings_list, WAIT_TIMEOUT
|
||||
)
|
||||
if wait_succeeded:
|
||||
logger.debug(
|
||||
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
|
||||
)
|
||||
break
|
||||
index_payload = redis_connector_index.payload
|
||||
if not index_payload:
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
"Wait for indexing soft termination timed out. "
|
||||
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
|
||||
)
|
||||
if not index_payload.celery_task_id:
|
||||
continue
|
||||
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings.id
|
||||
)
|
||||
if not redis_connector_index.fenced:
|
||||
continue
|
||||
# Revoke the task to prevent it from running
|
||||
primary_app.control.revoke(index_payload.celery_task_id)
|
||||
|
||||
index_payload = redis_connector_index.payload
|
||||
if not index_payload:
|
||||
continue
|
||||
# If it is running, then signaling for termination will get the
|
||||
# watchdog thread to kill the spawned task
|
||||
redis_connector_index.set_terminate(index_payload.celery_task_id)
|
||||
|
||||
if not index_payload.celery_task_id:
|
||||
continue
|
||||
|
||||
# Revoke the task to prevent it from running
|
||||
primary_app.control.revoke(index_payload.celery_task_id)
|
||||
|
||||
# If it is running, then signaling for termination will get the
|
||||
# watchdog thread to kill the spawned task
|
||||
redis_connector_index.set_terminate(index_payload.celery_task_id)
|
||||
|
||||
logger.debug(
|
||||
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
|
||||
)
|
||||
wait_succeeded = redis_connector.wait_for_indexing_termination(
|
||||
search_settings_list, WAIT_TIMEOUT
|
||||
)
|
||||
if wait_succeeded:
|
||||
logger.debug(
|
||||
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
|
||||
)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
|
||||
)
|
||||
still_terminating = True
|
||||
break
|
||||
finally:
|
||||
redis_connector.stop.set_fence(False)
|
||||
break
|
||||
else:
|
||||
redis_connector.stop.set_fence(False)
|
||||
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
@ -264,14 +223,6 @@ def update_cc_pair_status(
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if still_terminating:
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.ACCEPTED,
|
||||
content={
|
||||
"message": "Request accepted, background task termination still in progress"
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
|
||||
)
|
||||
|
@ -53,8 +53,9 @@ from onyx.connectors.google_utils.google_kv import (
|
||||
upsert_service_account_key,
|
||||
)
|
||||
from onyx.connectors.google_utils.google_kv import verify_csrf
|
||||
from onyx.connectors.google_utils.shared_constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.connector import create_connector
|
||||
from onyx.db.connector import delete_connector
|
||||
@ -314,6 +315,7 @@ def upsert_service_account_credential(
|
||||
credential_base = build_service_account_creds(
|
||||
DocumentSource.GOOGLE_DRIVE,
|
||||
primary_admin_email=service_account_credential_request.google_primary_admin,
|
||||
name="Service Account (uploaded)",
|
||||
)
|
||||
except KvKeyNotFoundError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@ -408,6 +410,38 @@ def upload_files(
|
||||
return FileUploadResponse(file_paths=deduped_file_paths)
|
||||
|
||||
|
||||
@router.get("/admin/connector")
|
||||
def get_connectors_by_credential(
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
credential: int | None = None,
|
||||
) -> list[ConnectorSnapshot]:
|
||||
"""Get a list of connectors. Allow filtering by a specific credential id."""
|
||||
|
||||
connectors = fetch_connectors(db_session)
|
||||
|
||||
filtered_connectors = []
|
||||
for connector in connectors:
|
||||
if connector.source == DocumentSource.INGESTION_API:
|
||||
# don't include INGESTION_API, as it's a system level
|
||||
# connector not manageable by the user
|
||||
continue
|
||||
|
||||
if credential is not None:
|
||||
found = False
|
||||
for cc_pair in connector.credentials:
|
||||
if credential == cc_pair.credential_id:
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
continue
|
||||
|
||||
filtered_connectors.append(ConnectorSnapshot.from_connector_db_model(connector))
|
||||
|
||||
return filtered_connectors
|
||||
|
||||
|
||||
# Retrieves most recent failure cases for connectors that are currently failing
|
||||
@router.get("/admin/connector/failed-indexing-status")
|
||||
def get_currently_failed_indexing_status(
|
||||
@ -987,7 +1021,12 @@ def gmail_callback(
|
||||
credential_id = int(credential_id_cookie)
|
||||
verify_csrf(credential_id, callback.state)
|
||||
credentials: Credentials | None = update_credential_access_tokens(
|
||||
callback.code, credential_id, user, db_session, DocumentSource.GMAIL
|
||||
callback.code,
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
DocumentSource.GMAIL,
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||
)
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
@ -1013,7 +1052,12 @@ def google_drive_callback(
|
||||
verify_csrf(credential_id, callback.state)
|
||||
|
||||
credentials: Credentials | None = update_credential_access_tokens(
|
||||
callback.code, credential_id, user, db_session, DocumentSource.GOOGLE_DRIVE
|
||||
callback.code,
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
DocumentSource.GOOGLE_DRIVE,
|
||||
GoogleOAuthAuthenticationMethod.UPLOADED,
|
||||
)
|
||||
if credentials is None:
|
||||
raise HTTPException(
|
||||
|
@ -9,7 +9,6 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.credentials import alter_credential
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
|
||||
from onyx.db.credentials import delete_credential
|
||||
@ -133,8 +132,6 @@ def create_credential_from_model(
|
||||
# Temporary fix for empty Google App credentials
|
||||
if credential_info.source == DocumentSource.GMAIL:
|
||||
cleanup_gmail_credentials(db_session=db_session)
|
||||
if credential_info.source == DocumentSource.GOOGLE_DRIVE:
|
||||
cleanup_google_drive_credentials(db_session=db_session)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
return ObjectCreationIdResponse(
|
||||
|
@ -142,19 +142,20 @@ def put_llm_provider(
|
||||
detail=f"LLM Provider with name {llm_provider.name} already exists",
|
||||
)
|
||||
|
||||
# Ensure default_model_name and fast_default_model_name are in display_model_names
|
||||
# This is necessary for custom models and Bedrock/Azure models
|
||||
if llm_provider.display_model_names is None:
|
||||
llm_provider.display_model_names = []
|
||||
if llm_provider.display_model_names is not None:
|
||||
# Ensure default_model_name and fast_default_model_name are in display_model_names
|
||||
# This is necessary for custom models and Bedrock/Azure models
|
||||
if llm_provider.default_model_name not in llm_provider.display_model_names:
|
||||
llm_provider.display_model_names.append(llm_provider.default_model_name)
|
||||
|
||||
if llm_provider.default_model_name not in llm_provider.display_model_names:
|
||||
llm_provider.display_model_names.append(llm_provider.default_model_name)
|
||||
|
||||
if (
|
||||
llm_provider.fast_default_model_name
|
||||
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
|
||||
):
|
||||
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
|
||||
if (
|
||||
llm_provider.fast_default_model_name
|
||||
and llm_provider.fast_default_model_name
|
||||
not in llm_provider.display_model_names
|
||||
):
|
||||
llm_provider.display_model_names.append(
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
|
||||
try:
|
||||
return upsert_llm_provider(
|
||||
|
@ -4,8 +4,12 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.reset import reset_all_multitenant
|
||||
@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create(name="admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_multitenant() -> None:
|
||||
reset_all_multitenant()
|
||||
|
@ -7,40 +7,12 @@ import requests
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user import UserRole
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create("admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
@ -0,0 +1,120 @@
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
|
||||
|
||||
|
||||
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
|
||||
"""Utility function to fetch an LLM provider by ID"""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
return next((p for p in providers if p["id"] == provider_id), None)
|
||||
|
||||
|
||||
def test_create_llm_provider_without_display_model_names(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test creating an LLM provider without specifying
|
||||
display_model_names and verify it's null in response"""
|
||||
# Create LLM provider without model_names
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": str(uuid.uuid4()),
|
||||
"provider": "openai",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
|
||||
# Verify model_names is None/null
|
||||
assert provider_data is not None
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
|
||||
assert provider_data["display_model_names"] is None
|
||||
|
||||
|
||||
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
|
||||
"""Test updating an LLM provider's model_names"""
|
||||
# First create provider without model_names
|
||||
name = str(uuid.uuid4())
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": name,
|
||||
"provider": "openai",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": [_DEFAULT_MODELS[0]],
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Update with model_names
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"id": created_provider["id"],
|
||||
"name": name,
|
||||
"provider": created_provider["provider"],
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify update
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is not None
|
||||
assert provider_data["model_names"] == _DEFAULT_MODELS
|
||||
|
||||
|
||||
def test_delete_llm_provider(admin_user: DATestUser) -> None:
|
||||
"""Test deleting an LLM provider"""
|
||||
# Create a provider
|
||||
response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
json={
|
||||
"name": "test-provider-delete",
|
||||
"provider": "openai",
|
||||
"default_model_name": _DEFAULT_MODELS[0],
|
||||
"model_names": _DEFAULT_MODELS,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
created_provider = response.json()
|
||||
|
||||
# Delete the provider
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify provider is deleted by checking it's not in the list
|
||||
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
|
||||
assert provider_data is None
|
@ -8,7 +8,13 @@ import { deleteCCPair } from "@/lib/documentDeletion";
|
||||
import { mutate } from "swr";
|
||||
import { buildCCPairInfoUrl } from "./lib";
|
||||
|
||||
export function DeletionButton({ ccPair }: { ccPair: CCPairFullInfo }) {
|
||||
export function DeletionButton({
|
||||
ccPair,
|
||||
refresh,
|
||||
}: {
|
||||
ccPair: CCPairFullInfo;
|
||||
refresh: () => void;
|
||||
}) {
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const isDeleting =
|
||||
@ -31,14 +37,22 @@ export function DeletionButton({ ccPair }: { ccPair: CCPairFullInfo }) {
|
||||
{popup}
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={() =>
|
||||
deleteCCPair(
|
||||
ccPair.connector.id,
|
||||
ccPair.credential.id,
|
||||
setPopup,
|
||||
() => mutate(buildCCPairInfoUrl(ccPair.id))
|
||||
)
|
||||
}
|
||||
onClick={async () => {
|
||||
try {
|
||||
// Await the delete operation to ensure it completes
|
||||
await deleteCCPair(
|
||||
ccPair.connector.id,
|
||||
ccPair.credential.id,
|
||||
setPopup,
|
||||
() => mutate(buildCCPairInfoUrl(ccPair.id))
|
||||
);
|
||||
|
||||
// Call refresh to update the state after deletion
|
||||
refresh();
|
||||
} catch (error) {
|
||||
console.error("Error deleting connector:", error);
|
||||
}
|
||||
}}
|
||||
icon={FiTrash}
|
||||
disabled={
|
||||
ccPair.status === ConnectorCredentialPairStatus.ACTIVE || isDeleting
|
||||
|
@ -362,7 +362,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
<div className="flex mt-4">
|
||||
<div className="mx-auto">
|
||||
{ccPair.is_editable_for_current_user && (
|
||||
<DeletionButton ccPair={ccPair} />
|
||||
<DeletionButton ccPair={ccPair} refresh={refresh} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
Loading…
x
Reference in New Issue
Block a user