From 3fec7a6a30541d069a989eac93ef376bcb25ac34 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 13 Dec 2024 10:05:06 -0800 Subject: [PATCH] post rebase fixes --- .vscode/launch.template.jsonc | 18 +-- backend/onyx/connectors/factory.py | 2 - .../connectors/google_utils/google_auth.py | 104 +++++++++++-- .../google_utils/shared_constants.py | 21 ++- .../onyx/connectors/slack/load_connector.py | 140 ------------------ backend/onyx/server/oauth.py | 48 +++--- .../daily/connectors/google_drive/conftest.py | 5 +- 7 files changed, 143 insertions(+), 195 deletions(-) delete mode 100644 backend/onyx/connectors/slack/load_connector.py diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc index 1f1faed09..5404c4a68 100644 --- a/.vscode/launch.template.jsonc +++ b/.vscode/launch.template.jsonc @@ -17,7 +17,7 @@ } }, { - "name": "Run All Danswer Services", + "name": "Run All Onyx Services", "configurations": [ "Web Server", "Model Server", @@ -122,7 +122,7 @@ "PYTHONUNBUFFERED": "1" }, "args": [ - "danswer.main:app", + "onyx.main:app", "--reload", "--port", "8080" @@ -139,7 +139,7 @@ "consoleName": "Slack Bot", "type": "debugpy", "request": "launch", - "program": "danswer/danswerbot/slack/listener.py", + "program": "onyx/onyxbot/slack/listener.py", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { @@ -166,7 +166,7 @@ }, "args": [ "-A", - "danswer.background.celery.versioned_apps.primary", + "onyx.background.celery.versioned_apps.primary", "worker", "--pool=threads", "--concurrency=4", @@ -195,7 +195,7 @@ }, "args": [ "-A", - "danswer.background.celery.versioned_apps.light", + "onyx.background.celery.versioned_apps.light", "worker", "--pool=threads", "--concurrency=64", @@ -224,7 +224,7 @@ }, "args": [ "-A", - "danswer.background.celery.versioned_apps.heavy", + "onyx.background.celery.versioned_apps.heavy", "worker", "--pool=threads", "--concurrency=4", @@ -254,7 +254,7 @@ }, "args": [ "-A", - "danswer.background.celery.versioned_apps.indexing", + "onyx.background.celery.versioned_apps.indexing", "worker", "--pool=threads", "--concurrency=1", @@ -283,7 +283,7 @@ }, "args": [ "-A", - "danswer.background.celery.versioned_apps.beat", + "onyx.background.celery.versioned_apps.beat", "beat", "--loglevel=INFO", ], @@ -308,7 +308,7 @@ "args": [ "-v" // Specify a sepcific module/test to run or provide nothing to run all tests - //"tests/unit/danswer/llm/answering/test_prune_and_merge.py" + //"tests/unit/onyx/llm/answering/test_prune_and_merge.py" ], "presentation": { "group": "2", diff --git a/backend/onyx/connectors/factory.py b/backend/onyx/connectors/factory.py index 34070edcc..5da314d58 100644 --- a/backend/onyx/connectors/factory.py +++ b/backend/onyx/connectors/factory.py @@ -41,7 +41,6 @@ from onyx.connectors.salesforce.connector import SalesforceConnector from onyx.connectors.sharepoint.connector import SharepointConnector from onyx.connectors.slab.connector import SlabConnector from onyx.connectors.slack.connector import SlackPollConnector -from onyx.connectors.slack.load_connector import SlackLoadConnector from onyx.connectors.teams.connector import TeamsConnector from onyx.connectors.web.connector import WebConnector from onyx.connectors.wikipedia.connector import WikipediaConnector @@ -64,7 +63,6 @@ def identify_connector_class( DocumentSource.WEB: WebConnector, DocumentSource.FILE: LocalFileConnector, DocumentSource.SLACK: { - InputType.LOAD_STATE: SlackLoadConnector, InputType.POLL: SlackPollConnector, InputType.SLIM_RETRIEVAL: SlackPollConnector, }, diff --git a/backend/onyx/connectors/google_utils/google_auth.py b/backend/onyx/connectors/google_utils/google_auth.py index bdca66795..40210e6f3 100644 --- a/backend/onyx/connectors/google_utils/google_auth.py +++ b/backend/onyx/connectors/google_utils/google_auth.py @@ -1,11 +1,16 @@ import json -from typing import cast +from typing import Any from google.auth.transport.requests import Request # type: ignore from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET from onyx.configs.constants import DocumentSource +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -18,14 +23,40 @@ from onyx.connectors.google_utils.shared_constants import ( from onyx.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) +from onyx.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) from onyx.utils.logger import setup_logger logger = setup_logger() +def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str: + """we really don't want to be persisting the client id and secret anywhere but the + environment. + + Returns a string of serialized json. + """ + + # strip the client id and secret + oauth_creds_json_str = oauth_creds.to_json() + oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str) + oauth_creds_sanitized_json.pop("client_id", None) + oauth_creds_sanitized_json.pop("client_secret", None) + oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json) + return oauth_creds_sanitized_json_str + + def get_google_oauth_creds( token_json_str: str, source: DocumentSource ) -> OAuthCredentials | None: + """creds_json only needs to contain client_id, client_secret and refresh_token to + refresh the creds. + + expiry and token are optional ... however, if passing in expiry, token + should also be passed in or else we may not return any creds. + (probably a sign we should refactor the function) + """ creds_json = json.loads(token_json_str) creds = OAuthCredentials.from_authorized_user_info( info=creds_json, @@ -41,7 +72,7 @@ def get_google_oauth_creds( logger.notice("Refreshed Google Drive tokens.") return creds except Exception: - logger.exception("Failed to refresh google drive access token due to:") + logger.exception("Failed to refresh google drive access token") return None return None @@ -52,31 +83,72 @@ def get_google_creds( source: DocumentSource, ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: """Checks for two different types of credentials. - (1) A credential which holds a token acquired via a user going thorough + (1) A credential which holds a token acquired via a user going through the Google OAuth flow. (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. + + Return a tuple where: + The first element is the requested credentials + The second element is a new credentials dict that the caller should write back + to the db. This happens if token rotation occurs while loading credentials. """ oauth_creds = None service_creds = None new_creds_dict = None if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: # OAUTH - access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]) - oauth_creds = get_google_oauth_creds( - token_json_str=access_token_json_str, source=source + authentication_method: str = credentials.get( + DB_CREDENTIALS_AUTHENTICATION_METHOD, + GoogleOAuthAuthenticationMethod.UPLOADED.value, ) - # tell caller to update token stored in DB if it has changed - # (e.g. the token has been refreshed) - new_creds_json_str = oauth_creds.to_json() if oauth_creds else "" - if new_creds_json_str != access_token_json_str: - new_creds_dict = { - DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str, - DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ - DB_CREDENTIALS_PRIMARY_ADMIN_KEY - ], - } + credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY] + credentials_dict = json.loads(credentials_dict_str) + + # only send what get_google_oauth_creds needs + authorized_user_info = {} + + # oauth_interactive is sanitized and needs credentials from the environment + if ( + authentication_method + == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + ): + authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID + authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + else: + authorized_user_info["client_id"] = credentials_dict["client_id"] + authorized_user_info["client_secret"] = credentials_dict["client_secret"] + + authorized_user_info["refresh_token"] = credentials_dict["refresh_token"] + + authorized_user_info["token"] = credentials_dict["token"] + authorized_user_info["expiry"] = credentials_dict["expiry"] + + token_json_str = json.dumps(authorized_user_info) + oauth_creds = get_google_oauth_creds( + token_json_str=token_json_str, source=source + ) + + # tell caller to update token stored in DB if the refresh token changed + if oauth_creds: + if oauth_creds.refresh_token != authorized_user_info["refresh_token"]: + # if oauth_interactive, sanitize the credentials so they don't get stored in the db + if ( + authentication_method + == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + ): + oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds) + else: + oauth_creds_json_str = oauth_creds.to_json() + + new_creds_dict = { + DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str, + DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ + DB_CREDENTIALS_PRIMARY_ADMIN_KEY + ], + DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method, + } elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: # SERVICE ACCOUNT service_account_key_json_str = credentials[ diff --git a/backend/onyx/connectors/google_utils/shared_constants.py b/backend/onyx/connectors/google_utils/shared_constants.py index b896c6e3e..9e1a75536 100644 --- a/backend/onyx/connectors/google_utils/shared_constants.py +++ b/backend/onyx/connectors/google_utils/shared_constants.py @@ -1,3 +1,5 @@ +from enum import Enum as PyEnum + from onyx.configs.constants import DocumentSource # NOTE: do not need https://www.googleapis.com/auth/documents.readonly @@ -10,7 +12,7 @@ GOOGLE_SCOPES = { "https://www.googleapis.com/auth/admin.directory.user.readonly", ], DocumentSource.GMAIL: [ - "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/gmail.readonfromly", "https://www.googleapis.com/auth/admin.directory.user.readonly", "https://www.googleapis.com/auth/admin.directory.group.readonly", ], @@ -23,15 +25,28 @@ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" # The email saved for both auth types DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" +# https://developers.google.com/workspace/guides/create-credentials +# Internally defined authentication method type. +# The value must be one of "oauth_interactive" or "uploaded" +# Used to disambiguate whether credentials have already been created via +# certain methods and what actions we allow users to take +DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" + + +class GoogleOAuthAuthenticationMethod(str, PyEnum): + OAUTH_INTERACTIVE = "oauth_interactive" + UPLOADED = "uploaded" + + USER_FIELDS = "nextPageToken, users(primaryEmail)" # Error message substrings MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" # Documentation and error messages -SCOPE_DOC_URL = "https://docs.onyx.app/connectors/google_drive/overview" +SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview" ONYX_SCOPE_INSTRUCTIONS = ( - "You have upgraded Onyx without updating the Google Auth scopes. " + "You have upgraded Danswer without updating the Google Auth scopes. " f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}" ) diff --git a/backend/onyx/connectors/slack/load_connector.py b/backend/onyx/connectors/slack/load_connector.py deleted file mode 100644 index 492a41c7e..000000000 --- a/backend/onyx/connectors/slack/load_connector.py +++ /dev/null @@ -1,140 +0,0 @@ -import json -import os -from datetime import datetime -from datetime import timezone -from pathlib import Path -from typing import Any -from typing import cast - -from onyx.configs.app_configs import INDEX_BATCH_SIZE -from onyx.configs.constants import DocumentSource -from onyx.connectors.interfaces import GenerateDocumentsOutput -from onyx.connectors.interfaces import LoadConnector -from onyx.connectors.models import Document -from onyx.connectors.models import Section -from onyx.connectors.slack.connector import filter_channels -from onyx.connectors.slack.utils import get_message_link -from onyx.utils.logger import setup_logger - - -logger = setup_logger() - - -def get_event_time(event: dict[str, Any]) -> datetime | None: - ts = event.get("ts") - if not ts: - return None - return datetime.fromtimestamp(float(ts), tz=timezone.utc) - - -class SlackLoadConnector(LoadConnector): - # WARNING: DEPRECATED, DO NOT USE - def __init__( - self, - workspace: str, - export_path_str: str, - channels: list[str] | None = None, - # if specified, will treat the specified channel strings as - # regexes, and will only index channels that fully match the regexes - channel_regex_enabled: bool = False, - batch_size: int = INDEX_BATCH_SIZE, - ) -> None: - self.workspace = workspace - self.channels = channels - self.channel_regex_enabled = channel_regex_enabled - self.export_path_str = export_path_str - self.batch_size = batch_size - - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - if credentials: - logger.warning("Unexpected credentials provided for Slack Load Connector") - return None - - @staticmethod - def _process_batch_event( - slack_event: dict[str, Any], - channel: dict[str, Any], - matching_doc: Document | None, - workspace: str, - ) -> Document | None: - if ( - slack_event["type"] == "message" - and slack_event.get("subtype") != "channel_join" - ): - if matching_doc: - return Document( - id=matching_doc.id, - sections=matching_doc.sections - + [ - Section( - link=get_message_link( - event=slack_event, - workspace=workspace, - channel_id=channel["id"], - ), - text=slack_event["text"], - ) - ], - source=matching_doc.source, - semantic_identifier=matching_doc.semantic_identifier, - title="", # slack docs don't really have a "title" - doc_updated_at=get_event_time(slack_event), - metadata=matching_doc.metadata, - ) - - return Document( - id=slack_event["ts"], - sections=[ - Section( - link=get_message_link( - event=slack_event, - workspace=workspace, - channel_id=channel["id"], - ), - text=slack_event["text"], - ) - ], - source=DocumentSource.SLACK, - semantic_identifier=channel["name"], - title="", # slack docs don't really have a "title" - doc_updated_at=get_event_time(slack_event), - metadata={}, - ) - - return None - - def load_from_state(self) -> GenerateDocumentsOutput: - export_path = Path(self.export_path_str) - - with open(export_path / "channels.json") as f: - all_channels = json.load(f) - - filtered_channels = filter_channels( - all_channels, self.channels, self.channel_regex_enabled - ) - - document_batch: dict[str, Document] = {} - for channel_info in filtered_channels: - channel_dir_path = export_path / cast(str, channel_info["name"]) - channel_file_paths = [ - channel_dir_path / file_name - for file_name in os.listdir(channel_dir_path) - ] - for path in channel_file_paths: - with open(path) as f: - events = cast(list[dict[str, Any]], json.load(f)) - for slack_event in events: - doc = self._process_batch_event( - slack_event=slack_event, - channel=channel_info, - matching_doc=document_batch.get( - slack_event.get("thread_ts", "") - ), - workspace=self.workspace, - ) - if doc: - document_batch[doc.id] = doc - if len(document_batch) >= self.batch_size: - yield list(document_batch.values()) - - yield list(document_batch.values()) diff --git a/backend/onyx/server/oauth.py b/backend/onyx/server/oauth.py index f126fb60b..66630d875 100644 --- a/backend/onyx/server/oauth.py +++ b/backend/onyx/server/oauth.py @@ -12,36 +12,36 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel from sqlalchemy.orm import Session -from danswer.auth.users import current_user -from danswer.configs.app_configs import WEB_DOMAIN -from danswer.configs.constants import DocumentSource -from danswer.connectors.google_utils.google_auth import get_google_oauth_creds -from danswer.connectors.google_utils.google_auth import sanitize_oauth_credentials -from danswer.connectors.google_utils.shared_constants import ( +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET +from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET +from onyx.auth.users import current_user +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.connectors.google_utils.google_auth import get_google_oauth_creds +from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials +from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_AUTHENTICATION_METHOD, ) -from danswer.connectors.google_utils.shared_constants import ( +from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_TOKEN_KEY, ) -from danswer.connectors.google_utils.shared_constants import ( +from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) -from danswer.connectors.google_utils.shared_constants import ( +from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) -from danswer.db.credentials import create_credential -from danswer.db.engine import get_current_tenant_id -from danswer.db.engine import get_session -from danswer.db.models import User -from danswer.redis.redis_pool import get_redis_client -from danswer.server.documents.models import CredentialBase -from danswer.utils.logger import setup_logger -from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID -from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET -from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID -from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET -from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID -from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET +from onyx.db.credentials import create_credential +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase +from onyx.utils.logger import setup_logger logger = setup_logger() @@ -64,7 +64,7 @@ class SlackOAuth: TOKEN_URL = "https://slack.com/api/oauth.v2.access" - # SCOPE is per https://docs.danswer.dev/connectors/slack + # SCOPE is per https://docs.onyx.app/connectors/slack BOT_SCOPE = ( "channels:history," "channels:read," @@ -211,7 +211,7 @@ class GoogleDriveOAuth: TOKEN_URL = "https://oauth2.googleapis.com/token" - # SCOPE is per https://docs.danswer.dev/connectors/google-drive + # SCOPE is per https://docs.onyx.app/connectors/google-drive # TODO: Merge with or use google_utils.GOOGLE_SCOPES SCOPE = ( "https://www.googleapis.com/auth/drive.readonly%20" diff --git a/backend/tests/daily/connectors/google_drive/conftest.py b/backend/tests/daily/connectors/google_drive/conftest.py index 485eda64d..b8bf0cb5f 100644 --- a/backend/tests/daily/connectors/google_drive/conftest.py +++ b/backend/tests/daily/connectors/google_drive/conftest.py @@ -5,6 +5,9 @@ from collections.abc import Callable import pytest from onyx.connectors.google_drive.connector import GoogleDriveConnector +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) @@ -14,7 +17,7 @@ from onyx.connectors.google_utils.shared_constants import ( from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) -from danswer.connectors.google_utils.shared_constants import ( +from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) from tests.load_env_vars import load_env_vars