danswer/backend/onyx/connectors/credentials_provider.py
rkuo-danswer 909403a648
Feature/confluence oauth (#3477)
* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* early cut at google drive oauth

* second pass

* switch to production uri's

* try handling oauth_interactive differently

* pass through client id and secret if uploaded

* fix call

* fix test

* temporarily disable check for testing

* Revert "temporarily disable check for testing"

This reverts commit 4b5a022a5fe38b05355a561616068af8e969def2.

* support visibility in test

* missed file

* first cut at confluence oauth

* work in progress

* work in progress

* work in progress

* work in progress

* work in progress

* first cut at distributed locking

* WIP to make test work

* add some dev mode affordances and gate usage of redis behind dynamic credentials

* mypy and credentials provider fixes

* WIP

* fix created at

* fix setting initialValue on everything

* remove debugging, fix ??? some TextFormField issues

* npm fixes

* comment cleanup

* fix comments

* pin the size of the card section

* more review fixes

* more fixes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-28 03:48:51 +00:00

136 lines
4.3 KiB
Python

import uuid
from types import TracebackType
from typing import Any
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import Credential
from onyx.redis.redis_pool import get_redis_client
class OnyxDBCredentialsProvider(
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
"""Implementation to allow the connector to callback and update credentials in the db.
Required in cases where credentials can rotate while the connector is running.
"""
LOCK_TTL = 900 # TTL of the lock
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_id = credential_id
self.redis_client = get_redis_client(tenant_id=tenant_id)
# lock used to prevent overlapping renewal of credentials
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
def __enter__(self) -> "OnyxDBCredentialsProvider":
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
if not acquired:
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Release the lock when exiting the context."""
if self._lock and self._lock.owned():
self._lock.release()
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return str(self._credential_id)
def get_credentials(self) -> dict[str, Any]:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
credential = db_session.execute(
select(Credential).where(Credential.id == self._credential_id)
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
try:
credential = db_session.execute(
select(Credential)
.where(Credential.id == self._credential_id)
.with_for_update()
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
db_session.commit()
except Exception:
db_session.rollback()
raise
def is_dynamic(self) -> bool:
return True
class OnyxStaticCredentialsProvider(
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "OnyxStaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False