post rebase fixes

This commit is contained in:
pablodanswer 2024-12-13 10:05:06 -08:00
parent 21ec5ed795
commit 3fec7a6a30
7 changed files with 143 additions and 195 deletions

View File

@ -17,7 +17,7 @@
}
},
{
"name": "Run All Danswer Services",
"name": "Run All Onyx Services",
"configurations": [
"Web Server",
"Model Server",
@ -122,7 +122,7 @@
"PYTHONUNBUFFERED": "1"
},
"args": [
"danswer.main:app",
"onyx.main:app",
"--reload",
"--port",
"8080"
@ -139,7 +139,7 @@
"consoleName": "Slack Bot",
"type": "debugpy",
"request": "launch",
"program": "danswer/danswerbot/slack/listener.py",
"program": "onyx/onyxbot/slack/listener.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
@ -166,7 +166,7 @@
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.primary",
"onyx.background.celery.versioned_apps.primary",
"worker",
"--pool=threads",
"--concurrency=4",
@ -195,7 +195,7 @@
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.light",
"onyx.background.celery.versioned_apps.light",
"worker",
"--pool=threads",
"--concurrency=64",
@ -224,7 +224,7 @@
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.heavy",
"onyx.background.celery.versioned_apps.heavy",
"worker",
"--pool=threads",
"--concurrency=4",
@ -254,7 +254,7 @@
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.indexing",
"onyx.background.celery.versioned_apps.indexing",
"worker",
"--pool=threads",
"--concurrency=1",
@ -283,7 +283,7 @@
},
"args": [
"-A",
"danswer.background.celery.versioned_apps.beat",
"onyx.background.celery.versioned_apps.beat",
"beat",
"--loglevel=INFO",
],
@ -308,7 +308,7 @@
"args": [
"-v"
// Specify a sepcific module/test to run or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
//"tests/unit/onyx/llm/answering/test_prune_and_merge.py"
],
"presentation": {
"group": "2",

View File

@ -41,7 +41,6 @@ from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.load_connector import SlackLoadConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
@ -64,7 +63,6 @@ def identify_connector_class(
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
},

View File

@ -1,11 +1,16 @@
import json
from typing import cast
from typing import Any
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -18,14 +23,40 @@ from onyx.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str:
"""we really don't want to be persisting the client id and secret anywhere but the
environment.
Returns a string of serialized json.
"""
# strip the client id and secret
oauth_creds_json_str = oauth_creds.to_json()
oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str)
oauth_creds_sanitized_json.pop("client_id", None)
oauth_creds_sanitized_json.pop("client_secret", None)
oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json)
return oauth_creds_sanitized_json_str
def get_google_oauth_creds(
token_json_str: str, source: DocumentSource
) -> OAuthCredentials | None:
"""creds_json only needs to contain client_id, client_secret and refresh_token to
refresh the creds.
expiry and token are optional ... however, if passing in expiry, token
should also be passed in or else we may not return any creds.
(probably a sign we should refactor the function)
"""
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(
info=creds_json,
@ -41,7 +72,7 @@ def get_google_oauth_creds(
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception:
logger.exception("Failed to refresh google drive access token due to:")
logger.exception("Failed to refresh google drive access token")
return None
return None
@ -52,31 +83,72 @@ def get_google_creds(
source: DocumentSource,
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
"""Checks for two different types of credentials.
(1) A credential which holds a token acquired via a user going thorough
(1) A credential which holds a token acquired via a user going through
the Google OAuth flow.
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
Return a tuple where:
The first element is the requested credentials
The second element is a new credentials dict that the caller should write back
to the db. This happens if token rotation occurs while loading credentials.
"""
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
# OAUTH
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_oauth_creds(
token_json_str=access_token_json_str, source=source
authentication_method: str = credentials.get(
DB_CREDENTIALS_AUTHENTICATION_METHOD,
GoogleOAuthAuthenticationMethod.UPLOADED.value,
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
}
credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
credentials_dict = json.loads(credentials_dict_str)
# only send what get_google_oauth_creds needs
authorized_user_info = {}
# oauth_interactive is sanitized and needs credentials from the environment
if (
authentication_method
== GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
):
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
else:
authorized_user_info["client_id"] = credentials_dict["client_id"]
authorized_user_info["client_secret"] = credentials_dict["client_secret"]
authorized_user_info["refresh_token"] = credentials_dict["refresh_token"]
authorized_user_info["token"] = credentials_dict["token"]
authorized_user_info["expiry"] = credentials_dict["expiry"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=source
)
# tell caller to update token stored in DB if the refresh token changed
if oauth_creds:
if oauth_creds.refresh_token != authorized_user_info["refresh_token"]:
# if oauth_interactive, sanitize the credentials so they don't get stored in the db
if (
authentication_method
== GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
):
oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds)
else:
oauth_creds_json_str = oauth_creds.to_json()
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[
DB_CREDENTIALS_PRIMARY_ADMIN_KEY
],
DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method,
}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
# SERVICE ACCOUNT
service_account_key_json_str = credentials[

View File

@ -1,3 +1,5 @@
from enum import Enum as PyEnum
from onyx.configs.constants import DocumentSource
# NOTE: do not need https://www.googleapis.com/auth/documents.readonly
@ -10,7 +12,7 @@ GOOGLE_SCOPES = {
"https://www.googleapis.com/auth/admin.directory.user.readonly",
],
DocumentSource.GMAIL: [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.readonfromly",
"https://www.googleapis.com/auth/admin.directory.user.readonly",
"https://www.googleapis.com/auth/admin.directory.group.readonly",
],
@ -23,15 +25,28 @@ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key"
# The email saved for both auth types
DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin"
# https://developers.google.com/workspace/guides/create-credentials
# Internally defined authentication method type.
# The value must be one of "oauth_interactive" or "uploaded"
# Used to disambiguate whether credentials have already been created via
# certain methods and what actions we allow users to take
DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method"
class GoogleOAuthAuthenticationMethod(str, PyEnum):
OAUTH_INTERACTIVE = "oauth_interactive"
UPLOADED = "uploaded"
USER_FIELDS = "nextPageToken, users(primaryEmail)"
# Error message substrings
MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested"
# Documentation and error messages
SCOPE_DOC_URL = "https://docs.onyx.app/connectors/google_drive/overview"
SCOPE_DOC_URL = "https://docs.danswer.dev/connectors/google_drive/overview"
ONYX_SCOPE_INSTRUCTIONS = (
"You have upgraded Onyx without updating the Google Auth scopes. "
"You have upgraded Danswer without updating the Google Auth scopes. "
f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}"
)

View File

@ -1,140 +0,0 @@
import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.slack.connector import filter_channels
from onyx.connectors.slack.utils import get_message_link
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_event_time(event: dict[str, Any]) -> datetime | None:
ts = event.get("ts")
if not ts:
return None
return datetime.fromtimestamp(float(ts), tz=timezone.utc)
class SlackLoadConnector(LoadConnector):
# WARNING: DEPRECATED, DO NOT USE
def __init__(
self,
workspace: str,
export_path_str: str,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.export_path_str = export_path_str
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Slack Load Connector")
return None
@staticmethod
def _process_batch_event(
slack_event: dict[str, Any],
channel: dict[str, Any],
matching_doc: Document | None,
workspace: str,
) -> Document | None:
if (
slack_event["type"] == "message"
and slack_event.get("subtype") != "channel_join"
):
if matching_doc:
return Document(
id=matching_doc.id,
sections=matching_doc.sections
+ [
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=matching_doc.source,
semantic_identifier=matching_doc.semantic_identifier,
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata=matching_doc.metadata,
)
return Document(
id=slack_event["ts"],
sections=[
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=DocumentSource.SLACK,
semantic_identifier=channel["name"],
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata={},
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
export_path = Path(self.export_path_str)
with open(export_path / "channels.json") as f:
all_channels = json.load(f)
filtered_channels = filter_channels(
all_channels, self.channels, self.channel_regex_enabled
)
document_batch: dict[str, Document] = {}
for channel_info in filtered_channels:
channel_dir_path = export_path / cast(str, channel_info["name"])
channel_file_paths = [
channel_dir_path / file_name
for file_name in os.listdir(channel_dir_path)
]
for path in channel_file_paths:
with open(path) as f:
events = cast(list[dict[str, Any]], json.load(f))
for slack_event in events:
doc = self._process_batch_event(
slack_event=slack_event,
channel=channel_info,
matching_doc=document_batch.get(
slack_event.get("thread_ts", "")
),
workspace=self.workspace,
)
if doc:
document_batch[doc.id] = doc
if len(document_batch) >= self.batch_size:
yield list(document_batch.values())
yield list(document_batch.values())

View File

@ -12,36 +12,36 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_utils.google_auth import get_google_oauth_creds
from danswer.connectors.google_utils.google_auth import sanitize_oauth_credentials
from danswer.connectors.google_utils.shared_constants import (
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from danswer.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -64,7 +64,7 @@ class SlackOAuth:
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
# SCOPE is per https://docs.onyx.app/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
@ -211,7 +211,7 @@ class GoogleDriveOAuth:
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"

View File

@ -5,6 +5,9 @@ from collections.abc import Callable
import pytest
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@ -14,7 +17,7 @@ from onyx.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from danswer.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from tests.load_env_vars import load_env_vars