Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/light_cpu

This commit is contained in:
Richard Kuo (Danswer) 2025-01-10 11:01:12 -08:00
commit 8e25c3c412
13 changed files with 315 additions and 137 deletions

View File

@ -24,7 +24,9 @@ _REQUEST_PAGINATION_LIMIT = 5000
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
viewspace_permissions = []
for permission_category in space_permissions:

View File

@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import Document
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@ -41,14 +42,21 @@ def _get_deletion_status(
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
return None
if redis_connector.delete.fenced:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.PENDING,
)
return None
def get_deletion_attempt_snapshot(

View File

@ -354,6 +354,33 @@ class OnyxConfluence(Confluence):
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def get_all_space_permissions_server(
    self,
    space_key: str,
) -> list[dict[str, Any]]:
    """
    Fetch the permission sets of a space on Confluence Server.

    This is a Confluence Server specific method. It calls the JSON-RPC
    `getSpacePermissionSets` method rather than get_space_permissions, so
    the raw jsonrpc response can be logged for debugging.

    Args:
        space_key: key of the space whose permissions are requested.

    Returns:
        The `result` entry of the jsonrpc response, or an empty list (after
        logging a warning) when the response is missing, is not a JSON
        object, or carries no usable `result`.
    """
    url = "rpc/json-rpc/confluenceservice-v2"
    data = {
        "jsonrpc": "2.0",
        "method": "getSpacePermissionSets",
        "id": 7,
        "params": [space_key],
    }
    response = self.post(url, data=data)
    logger.debug(f"jsonrpc response: {response}")

    # The response is only usable when it is a JSON object; anything else
    # (e.g. None for an empty body) must not be dereferenced with .get().
    result = response.get("result") if isinstance(response, dict) else None
    if not result:
        logger.warning(
            f"No jsonrpc response for space permissions for space {space_key}"
            f"\nResponse: {response}"
        )
        # Always honour the declared list return type, even for a null result.
        return []
    return result
def _validate_connector_configuration(
credentials: dict[str, Any],

View File

@ -17,6 +17,9 @@ from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import get_gmail_service
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -29,6 +32,9 @@ from onyx.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.connectors.google_utils.shared_constants import (
MISSING_SCOPES_ERROR_STR,
)
@ -96,6 +102,7 @@ def update_credential_access_tokens(
user: User,
db_session: Session,
source: DocumentSource,
auth_method: GoogleOAuthAuthenticationMethod,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred(source)
flow = InstalledAppFlow.from_client_config(
@ -119,6 +126,7 @@ def update_credential_access_tokens(
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
@ -129,6 +137,7 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
name: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key(source=source)
@ -138,10 +147,15 @@ def build_service_account_creds(
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.UPLOADED.value
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=source,
name=name,
)

View File

@ -164,17 +164,12 @@ def update_cc_pair_status(
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""This method may wait up to 30 seconds if pausing the connector due to the need to
terminate tasks in progress. Tasks are not guaranteed to terminate within the
timeout.
"""This method returns nearly immediately. It simply sets some signals and
optimistically assumes any running background processes will clean themselves up.
This is done to improve the perceived end user experience.
Returns HTTPStatus.OK if everything finished.
Returns HTTPStatus.ACCEPTED if the connector is being paused, but background tasks
did not finish within the timeout.
"""
WAIT_TIMEOUT = 15.0
still_terminating = False
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@ -188,73 +183,37 @@ def update_cc_pair_status(
detail="Connection not found for current user's permissions",
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
redis_connector.stop.set_fence(True)
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
while True:
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
if not redis_connector_index.fenced:
continue
try:
redis_connector.stop.set_fence(True)
while True:
logger.debug(
f"Wait for indexing soft termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing soft termination succeeded: cc_pair={cc_pair_id}"
)
break
index_payload = redis_connector_index.payload
if not index_payload:
continue
logger.debug(
"Wait for indexing soft termination timed out. "
f"Moving to hard termination: cc_pair={cc_pair_id} timeout={WAIT_TIMEOUT:.2f}"
)
if not index_payload.celery_task_id:
continue
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
if not redis_connector_index.fenced:
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
index_payload = redis_connector_index.payload
if not index_payload:
continue
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
if not index_payload.celery_task_id:
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
redis_connector_index.set_terminate(index_payload.celery_task_id)
logger.debug(
f"Wait for indexing hard termination starting: cc_pair={cc_pair_id}"
)
wait_succeeded = redis_connector.wait_for_indexing_termination(
search_settings_list, WAIT_TIMEOUT
)
if wait_succeeded:
logger.debug(
f"Wait for indexing hard termination succeeded: cc_pair={cc_pair_id}"
)
break
logger.debug(
f"Wait for indexing hard termination timed out: cc_pair={cc_pair_id}"
)
still_terminating = True
break
finally:
redis_connector.stop.set_fence(False)
break
else:
redis_connector.stop.set_fence(False)
update_connector_credential_pair_from_id(
db_session=db_session,
@ -264,14 +223,6 @@ def update_cc_pair_status(
db_session.commit()
if still_terminating:
return JSONResponse(
status_code=HTTPStatus.ACCEPTED,
content={
"message": "Request accepted, background task termination still in progress"
},
)
return JSONResponse(
status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
)

View File

@ -53,8 +53,9 @@ from onyx.connectors.google_utils.google_kv import (
upsert_service_account_key,
)
from onyx.connectors.google_utils.google_kv import verify_csrf
from onyx.connectors.google_utils.shared_constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
GoogleOAuthAuthenticationMethod,
)
from onyx.db.connector import create_connector
from onyx.db.connector import delete_connector
@ -314,6 +315,7 @@ def upsert_service_account_credential(
credential_base = build_service_account_creds(
DocumentSource.GOOGLE_DRIVE,
primary_admin_email=service_account_credential_request.google_primary_admin,
name="Service Account (uploaded)",
)
except KvKeyNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@ -408,6 +410,38 @@ def upload_files(
return FileUploadResponse(file_paths=deduped_file_paths)
@router.get("/admin/connector")
def get_connectors_by_credential(
    _: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    credential: int | None = None,
) -> list[ConnectorSnapshot]:
    """Return connector snapshots, optionally restricted to one credential id."""
    snapshots = []
    for db_connector in fetch_connectors(db_session):
        # INGESTION_API is a system level connector that the user cannot
        # manage, so it is never listed.
        if db_connector.source == DocumentSource.INGESTION_API:
            continue

        # With a credential filter, keep only connectors that have a
        # cc_pair referencing that credential id.
        if credential is not None and not any(
            credential == cc_pair.credential_id
            for cc_pair in db_connector.credentials
        ):
            continue

        snapshots.append(ConnectorSnapshot.from_connector_db_model(db_connector))

    return snapshots
# Retrieves most recent failure cases for connectors that are currently failing
@router.get("/admin/connector/failed-indexing-status")
def get_currently_failed_indexing_status(
@ -987,7 +1021,12 @@ def gmail_callback(
credential_id = int(credential_id_cookie)
verify_csrf(credential_id, callback.state)
credentials: Credentials | None = update_credential_access_tokens(
callback.code, credential_id, user, db_session, DocumentSource.GMAIL
callback.code,
credential_id,
user,
db_session,
DocumentSource.GMAIL,
GoogleOAuthAuthenticationMethod.UPLOADED,
)
if credentials is None:
raise HTTPException(
@ -1013,7 +1052,12 @@ def google_drive_callback(
verify_csrf(credential_id, callback.state)
credentials: Credentials | None = update_credential_access_tokens(
callback.code, credential_id, user, db_session, DocumentSource.GOOGLE_DRIVE
callback.code,
credential_id,
user,
db_session,
DocumentSource.GOOGLE_DRIVE,
GoogleOAuthAuthenticationMethod.UPLOADED,
)
if credentials is None:
raise HTTPException(

View File

@ -9,7 +9,6 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.db.credentials import alter_credential
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
from onyx.db.credentials import delete_credential
@ -133,8 +132,6 @@ def create_credential_from_model(
# Temporary fix for empty Google App credentials
if credential_info.source == DocumentSource.GMAIL:
cleanup_gmail_credentials(db_session=db_session)
if credential_info.source == DocumentSource.GOOGLE_DRIVE:
cleanup_google_drive_credentials(db_session=db_session)
credential = create_credential(credential_info, user, db_session)
return ObjectCreationIdResponse(

View File

@ -142,19 +142,20 @@ def put_llm_provider(
detail=f"LLM Provider with name {llm_provider.name} already exists",
)
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.display_model_names is None:
llm_provider.display_model_names = []
if llm_provider.display_model_names is not None:
# Ensure default_model_name and fast_default_model_name are in display_model_names
# This is necessary for custom models and Bedrock/Azure models
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if llm_provider.default_model_name not in llm_provider.display_model_names:
llm_provider.display_model_names.append(llm_provider.default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(llm_provider.fast_default_model_name)
if (
llm_provider.fast_default_model_name
and llm_provider.fast_default_model_name
not in llm_provider.display_model_names
):
llm_provider.display_model_names.append(
llm_provider.fast_default_model_name
)
try:
return upsert_llm_provider(

View File

@ -4,8 +4,12 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
@ -57,6 +61,30 @@ def new_admin_user(reset: None) -> DATestUser | None:
return None
@pytest.fixture
def admin_user() -> DATestUser | None:
    """Best-effort fixture: create the "admin_user" account, else log in as it.

    Yields None to the test when both attempts raise.
    """
    # Attempt 1: create the user. Any failure is deliberately ignored so
    # that the login fallback below gets a chance.
    try:
        created = UserManager.create(name="admin_user")
    except Exception:
        pass
    else:
        return created

    # Attempt 2: log in with the default credentials for that user.
    try:
        candidate = DATestUser(
            id="",
            email=build_email("admin_user"),
            password=DEFAULT_PASSWORD,
            headers=GENERAL_HEADERS,
            role=UserRole.ADMIN,
            is_active=True,
        )
        logged_in = UserManager.login_as_user(candidate)
    except Exception:
        return None
    else:
        return logged_in
@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()

View File

@ -7,40 +7,12 @@ import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user import UserRole
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def admin_user() -> DATestUser | None:
    """Best-effort fixture: create the "admin_user" account, else log in as it.

    Yields None to the test when both attempts raise.
    """
    # Attempt 1: create the user. Any failure is deliberately ignored so
    # that the login fallback below gets a chance.
    try:
        created = UserManager.create("admin_user")
    except Exception:
        pass
    else:
        return created

    # Attempt 2: log in with the default credentials for that user.
    try:
        candidate = DATestUser(
            id="",
            email=build_email("admin_user"),
            password=DEFAULT_PASSWORD,
            headers=GENERAL_HEADERS,
            role=UserRole.ADMIN,
            is_active=True,
        )
        logged_in = UserManager.login_as_user(candidate)
    except Exception:
        return None
    else:
        return logged_in
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)

View File

@ -0,0 +1,120 @@
import uuid
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
_DEFAULT_MODELS = ["gpt-4", "gpt-4o"]
def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None:
    """Look up one LLM provider by ID in the admin provider listing.

    Returns the first listed provider whose "id" equals provider_id, or None
    when no listed provider matches.
    """
    listing = requests.get(
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
    )
    assert listing.status_code == 200

    for provider in listing.json():
        if provider["id"] == provider_id:
            return provider
    return None
def test_create_llm_provider_without_display_model_names(
    admin_user: DATestUser,
) -> None:
    """Create an LLM provider without display_model_names in the request and
    check that the listed provider reports display_model_names as null."""
    provider_url = f"{API_SERVER_URL}/admin/llm/provider"

    # display_model_names is deliberately absent from this payload
    payload = {
        "name": str(uuid.uuid4()),
        "provider": "openai",
        "default_model_name": _DEFAULT_MODELS[0],
        "model_names": _DEFAULT_MODELS,
        "is_public": True,
        "groups": [],
    }
    response = requests.put(provider_url, headers=admin_user.headers, json=payload)
    assert response.status_code == 200

    new_id = response.json()["id"]
    provider_data = _get_provider_by_id(admin_user, new_id)

    # The submitted fields are present and display_model_names stays null
    assert provider_data is not None
    assert provider_data["model_names"] == _DEFAULT_MODELS
    assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
    assert provider_data["display_model_names"] is None
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
    """Create an LLM provider with a single model name, then update it with
    the full model list and check the listing reflects the update."""
    provider_url = f"{API_SERVER_URL}/admin/llm/provider"
    name = str(uuid.uuid4())

    # Step 1: create the provider with only the first default model
    create_payload = {
        "name": name,
        "provider": "openai",
        "default_model_name": _DEFAULT_MODELS[0],
        "model_names": [_DEFAULT_MODELS[0]],
        "is_public": True,
        "groups": [],
    }
    response = requests.put(
        provider_url, headers=admin_user.headers, json=create_payload
    )
    assert response.status_code == 200
    created_provider = response.json()

    # Step 2: send the same provider again (by id) with the full model list
    update_payload = {
        "id": created_provider["id"],
        "name": name,
        "provider": created_provider["provider"],
        "default_model_name": _DEFAULT_MODELS[0],
        "model_names": _DEFAULT_MODELS,
        "is_public": True,
        "groups": [],
    }
    response = requests.put(
        provider_url, headers=admin_user.headers, json=update_payload
    )
    assert response.status_code == 200

    # Step 3: the listing must now show the full model list
    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
    assert provider_data is not None
    assert provider_data["model_names"] == _DEFAULT_MODELS
def test_delete_llm_provider(admin_user: DATestUser) -> None:
    """Create an LLM provider, delete it, and check it is gone from the listing."""
    # Use a unique name, like the sibling tests do: a fixed name would be
    # rejected as "already exists" if an earlier run left the provider
    # behind (e.g. the delete step failed and no reset happened).
    response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
        json={
            "name": str(uuid.uuid4()),
            "provider": "openai",
            "default_model_name": _DEFAULT_MODELS[0],
            "model_names": _DEFAULT_MODELS,
            "is_public": True,
            "groups": [],
        },
    )
    assert response.status_code == 200
    created_provider = response.json()

    # Delete the provider
    response = requests.delete(
        f"{API_SERVER_URL}/admin/llm/provider/{created_provider['id']}",
        headers=admin_user.headers,
    )
    assert response.status_code == 200

    # Verify the provider is deleted by checking it is not in the list
    provider_data = _get_provider_by_id(admin_user, created_provider["id"])
    assert provider_data is None

View File

@ -8,7 +8,13 @@ import { deleteCCPair } from "@/lib/documentDeletion";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
export function DeletionButton({ ccPair }: { ccPair: CCPairFullInfo }) {
export function DeletionButton({
ccPair,
refresh,
}: {
ccPair: CCPairFullInfo;
refresh: () => void;
}) {
const { popup, setPopup } = usePopup();
const isDeleting =
@ -31,14 +37,22 @@ export function DeletionButton({ ccPair }: { ccPair: CCPairFullInfo }) {
{popup}
<Button
variant="destructive"
onClick={() =>
deleteCCPair(
ccPair.connector.id,
ccPair.credential.id,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
)
}
onClick={async () => {
try {
// Await the delete operation to ensure it completes
await deleteCCPair(
ccPair.connector.id,
ccPair.credential.id,
setPopup,
() => mutate(buildCCPairInfoUrl(ccPair.id))
);
// Call refresh to update the state after deletion
refresh();
} catch (error) {
console.error("Error deleting connector:", error);
}
}}
icon={FiTrash}
disabled={
ccPair.status === ConnectorCredentialPairStatus.ACTIVE || isDeleting

View File

@ -362,7 +362,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
<div className="flex mt-4">
<div className="mx-auto">
{ccPair.is_editable_for_current_user && (
<DeletionButton ccPair={ccPair} />
<DeletionButton ccPair={ccPair} refresh={refresh} />
)}
</div>
</div>