mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-27 18:22:55 +01:00
* first cut at slack oauth flow * fix usage of hooks * fix button spacing * add additional error logging * no dev redirect * early cut at google drive oauth * second pass * switch to production uri's * try handling oauth_interactive differently * pass through client id and secret if uploaded * fix call * fix test * temporarily disable check for testing * Revert "temporarily disable check for testing" This reverts commit 4b5a022a5fe38b05355a561616068af8e969def2. * support visibility in test * missed file * first cut at confluence oauth * work in progress * work in progress * work in progress * work in progress * work in progress * first cut at distributed locking * WIP to make test work * add some dev mode affordances and gate usage of redis behind dynamic credentials * mypy and credentials provider fixes * WIP * fix created at * fix setting initialValue on everything * remove debugging, fix ??? some TextFormField issues * npm fixes * comment cleanup * fix comments * pin the size of the card section * more review fixes * more fixes --------- Co-authored-by: Richard Kuo <rkuo@rkuo.com> Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
136 lines
4.3 KiB
Python
136 lines
4.3 KiB
Python
import uuid
|
|
from types import TracebackType
|
|
from typing import Any
|
|
|
|
from redis.lock import Lock as RedisLock
|
|
from sqlalchemy import select
|
|
|
|
from onyx.connectors.interfaces import CredentialsProviderInterface
|
|
from onyx.db.engine import get_session_with_tenant
|
|
from onyx.db.models import Credential
|
|
from onyx.redis.redis_pool import get_redis_client
|
|
|
|
|
|
class OnyxDBCredentialsProvider(
    CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
    """Implementation to allow the connector to callback and update credentials in the db.

    Required in cases where credentials can rotate while the connector is running.

    Intended to be used as a context manager. Entering acquires a Redis lock keyed
    on the connector name and credential id, so overlapping renewals of the same
    credential are serialized. Exiting releases the lock if this instance still
    owns it.
    """

    # Lock TTL passed to redis as the lock timeout; also reused as the maximum
    # time __enter__ waits to acquire the lock (seconds, per redis-py's lock API).
    LOCK_TTL = 900

    def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
        self._tenant_id = tenant_id
        self._connector_name = connector_name
        self._credential_id = credential_id

        # redis client obtained per tenant (presumably tenant-scoped; confirm in redis_pool)
        self.redis_client = get_redis_client(tenant_id=tenant_id)

        # lock used to prevent overlapping renewal of credentials
        self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
        self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)

    def __enter__(self) -> "OnyxDBCredentialsProvider":
        """Acquire the renewal lock, waiting up to LOCK_TTL.

        Raises:
            RuntimeError: if the lock could not be acquired within the wait time.
        """
        acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
        if not acquired:
            raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Release the lock when exiting the context."""
        # Only release a lock we still hold; it may have expired after LOCK_TTL.
        if self._lock and self._lock.owned():
            self._lock.release()

    def get_tenant_id(self) -> str | None:
        """Return the tenant id this provider was constructed with."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the credential id as a string, used as this provider's key."""
        return str(self._credential_id)

    def get_credentials(self) -> dict[str, Any]:
        """Load the current ``credential_json`` for this credential id from the db.

        Returns:
            The stored credential JSON.

        Raises:
            ValueError: if no Credential row exists for the configured id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            # scalar_one_or_none (not scalar_one) so a missing row reaches the
            # ValueError below instead of surfacing as sqlalchemy's NoResultFound.
            credential = db_session.execute(
                select(Credential).where(Credential.id == self._credential_id)
            ).scalar_one_or_none()

            if credential is None:
                raise ValueError(
                    f"No credential found: credential={self._credential_id}"
                )

            return credential.credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Overwrite this credential's ``credential_json`` in the db and commit.

        The row is selected ``FOR UPDATE`` so that concurrent writers to the same
        row are serialized at the database level. Any failure rolls the session
        back before re-raising.

        Raises:
            ValueError: if no Credential row exists for the configured id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            try:
                # scalar_one_or_none so the "not found" case raises ValueError
                # as intended rather than NoResultFound.
                credential = db_session.execute(
                    select(Credential)
                    .where(Credential.id == self._credential_id)
                    .with_for_update()
                ).scalar_one_or_none()

                if credential is None:
                    raise ValueError(
                        f"No credential found: credential={self._credential_id}"
                    )

                credential.credential_json = credential_json
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise

    def is_dynamic(self) -> bool:
        """Report that these credentials live in the db and may be rotated."""
        return True
|
|
|
|
|
|
class OnyxStaticCredentialsProvider(
    CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
    """Credentials provider that simply holds a dict in memory.

    There is no database access and no locking. ``set_credentials`` replaces the
    dict held by this instance only, and each instance is assigned a fresh
    random provider key.
    """

    def __init__(
        self,
        tenant_id: str | None,
        connector_name: str,
        credential_json: dict[str, Any],
    ):
        self._tenant_id = tenant_id
        self._connector_name = connector_name
        self._credential_json = credential_json

        # Random per-instance key, so two static providers never share one.
        self._provider_key = f"{uuid.uuid4()}"

    def __enter__(self) -> "OnyxStaticCredentialsProvider":
        # Nothing to acquire for static credentials.
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Nothing to release; returning None lets any in-flight exception propagate.
        return None

    def get_tenant_id(self) -> str | None:
        """Return the tenant id supplied at construction (may be None)."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the random UUID string generated for this instance."""
        return self._provider_key

    def get_credentials(self) -> dict[str, Any]:
        """Return the held credential dict (the same object, not a copy)."""
        return self._credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Replace the in-memory credential dict for this instance."""
        self._credential_json = credential_json

    def is_dynamic(self) -> bool:
        """Report that these credentials are static."""
        return False
|