diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..9ba0e16ce --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @onyx-dot-app/onyx-core-team diff --git a/.github/workflows/nightly-scan-licenses.yml b/.github/workflows/nightly-scan-licenses.yml index 9aa7030e0..d57917981 100644 --- a/.github/workflows/nightly-scan-licenses.yml +++ b/.github/workflows/nightly-scan-licenses.yml @@ -53,24 +53,90 @@ jobs: exclude: '(?i)^(pylint|aio[-_]*).*' - name: Print report - if: ${{ always() }} + if: always() run: echo "${{ steps.license_check_report.outputs.report }}" - name: Install npm dependencies working-directory: ./web run: npm ci - - - name: Run Trivy vulnerability scanner in repo mode - uses: aquasecurity/trivy-action@0.28.0 - with: - scan-type: fs - scanners: license - format: table -# format: sarif -# output: trivy-results.sarif - severity: HIGH,CRITICAL -# - name: Upload Trivy scan results to GitHub Security tab -# uses: github/codeql-action/upload-sarif@v3 + # be careful enabling the sarif and upload as it may spam the security tab + # with a huge amount of items. Work out the issues before enabling upload. +# - name: Run Trivy vulnerability scanner in repo mode +# if: always() +# uses: aquasecurity/trivy-action@0.29.0 # with: -# sarif_file: trivy-results.sarif +# scan-type: fs +# scan-ref: . +# scanners: license +# format: table +# severity: HIGH,CRITICAL +# # format: sarif +# # output: trivy-results.sarif +# +# # - name: Upload Trivy scan results to GitHub Security tab +# # uses: github/codeql-action/upload-sarif@v3 +# # with: +# # sarif_file: trivy-results.sarif + + scan-trivy: + # See https://runs-on.com/runners/linux/ + runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"] + + steps: + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} + + # Backend + - name: Pull backend docker image + run: docker pull onyxdotapp/onyx-backend:latest + + - name: Run Trivy vulnerability scanner on backend + uses: aquasecurity/trivy-action@0.29.0 + env: + TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2' + TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1' + with: + image-ref: onyxdotapp/onyx-backend:latest + scanners: license + severity: HIGH,CRITICAL + vuln-type: library + exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow + + # Web server + - name: Pull web server docker image + run: docker pull onyxdotapp/onyx-web-server:latest + + - name: Run Trivy vulnerability scanner on web server + uses: aquasecurity/trivy-action@0.29.0 + env: + TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2' + TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1' + with: + image-ref: onyxdotapp/onyx-web-server:latest + scanners: license + severity: HIGH,CRITICAL + vuln-type: library + exit-code: 0 + + # Model server + - name: Pull model server docker image + run: docker pull onyxdotapp/onyx-model-server:latest + + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@0.29.0 + env: + TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2' + TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1' + with: + image-ref: onyxdotapp/onyx-model-server:latest + scanners: license + severity: HIGH,CRITICAL + vuln-type: library + exit-code: 0 \ No newline at end of file diff --git 
a/backend/alembic/versions/b7c2b63c4a03_add_background_reindex_enabled_field.py b/backend/alembic/versions/b7c2b63c4a03_add_background_reindex_enabled_field.py
new file mode 100644
index 000000000..cf31f9f27
--- /dev/null
+++ b/backend/alembic/versions/b7c2b63c4a03_add_background_reindex_enabled_field.py
@@ -0,0 +1,55 @@
+"""add background_reindex_enabled field
+
+Revision ID: b7c2b63c4a03
+Revises: f11b408e39d3
+Create Date: 2024-03-26 12:34:56.789012
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+from onyx.db.enums import EmbeddingPrecision
+
+
+# revision identifiers, used by Alembic.
+revision = "b7c2b63c4a03"
+down_revision = "f11b408e39d3"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # Add background_reindex_enabled column with default value of True
+    op.add_column(
+        "search_settings",
+        sa.Column(
+            "background_reindex_enabled",
+            sa.Boolean(),
+            nullable=False,
+            server_default="true",
+        ),
+    )
+
+    # Add embedding_precision column with default value of FLOAT
+    op.add_column(
+        "search_settings",
+        sa.Column(
+            "embedding_precision",
+            sa.Enum(EmbeddingPrecision, native_enum=False),
+            nullable=False,
+            server_default=EmbeddingPrecision.FLOAT.name,
+        ),
+    )
+
+    # Add reduced_dimension column, nullable (defaults to None)
+    op.add_column(
+        "search_settings",
+        sa.Column("reduced_dimension", sa.Integer(), nullable=True),
+    )
+
+
+def downgrade() -> None:
+    # Remove all three columns added in upgrade()
+    op.drop_column("search_settings", "background_reindex_enabled")
+    op.drop_column("search_settings", "embedding_precision")
+    op.drop_column("search_settings", "reduced_dimension")
diff --git a/backend/alembic/versions/f11b408e39d3_force_lowercase_all_users.py b/backend/alembic/versions/f11b408e39d3_force_lowercase_all_users.py
new file mode 100644
index 000000000..48684ea88
--- /dev/null
+++ b/backend/alembic/versions/f11b408e39d3_force_lowercase_all_users.py
@@ -0,0 +1,36 @@
+"""force lowercase all users
+
+Revision ID: f11b408e39d3
+Revises: 3bd4c84fe72f
+Create Date: 2025-02-26 17:04:55.683500
+
+"""
+
+
+# revision identifiers, used by Alembic.
+revision = "f11b408e39d3"
+down_revision = "3bd4c84fe72f"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # 1) Convert all existing user emails to lowercase
+    from alembic import op
+
+    op.execute(
+        """
+        UPDATE "user"
+        SET email = LOWER(email)
+        """
+    )
+
+    # 2) Add a check constraint to ensure emails are always lowercase
+    op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
+
+
+def downgrade() -> None:
+    # Drop the check constraint
+    from alembic import op
+
+    op.drop_constraint("ensure_lowercase_email", "user", type_="check")
diff --git a/backend/alembic_tenants/versions/34e3630c7f32_lowercase_multi_tenant_user_auth.py b/backend/alembic_tenants/versions/34e3630c7f32_lowercase_multi_tenant_user_auth.py
new file mode 100644
index 000000000..c98fc2ca2
--- /dev/null
+++ b/backend/alembic_tenants/versions/34e3630c7f32_lowercase_multi_tenant_user_auth.py
@@ -0,0 +1,42 @@
+"""lowercase multi-tenant user auth
+
+Revision ID: 34e3630c7f32
+Revises: a4f6ee863c47
+Create Date: 2025-02-26 15:03:01.211894
+
+"""
+from alembic import op
+
+
+# revision identifiers, used by Alembic.
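+# (This revision mirrors f11b408e39d3 above, which applies the same
+# lowercase-email rule to the single-tenant "user" table.)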
+revision = "34e3630c7f32" +down_revision = "a4f6ee863c47" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # 1) Convert all existing rows to lowercase + op.execute( + """ + UPDATE user_tenant_mapping + SET email = LOWER(email) + """ + ) + # 2) Add a check constraint so that emails cannot be written in uppercase + op.create_check_constraint( + "ensure_lowercase_email", + "user_tenant_mapping", + "email = LOWER(email)", + schema="public", + ) + + +def downgrade() -> None: + # Drop the check constraint + op.drop_constraint( + "ensure_lowercase_email", + "user_tenant_mapping", + schema="public", + type_="check", + ) diff --git a/backend/ee/onyx/configs/app_configs.py b/backend/ee/onyx/configs/app_configs.py index b567db38a..3c2b1638c 100644 --- a/backend/ee/onyx/configs/app_configs.py +++ b/backend/ee/onyx/configs/app_configs.py @@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key") OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "") OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "") -OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "") -OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "") -OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "") -OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "") +OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get( + "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", "" +) +OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get( + "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", "" +) +OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "") +OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "") OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "") OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" diff --git a/backend/ee/onyx/db/user_group.py b/backend/ee/onyx/db/user_group.py index c2a36d330..34db3c5f9 100644 --- a/backend/ee/onyx/db/user_group.py +++ b/backend/ee/onyx/db/user_group.py @@ -424,7 +424,7 @@ def _validate_curator_status__no_commit( ) # if the user is a curator in any of their groups, set their role to CURATOR - # otherwise, set their role to BASIC + # otherwise, set their role to BASIC only if they were previously a CURATOR if curator_relationships: user.role = UserRole.CURATOR elif user.role == UserRole.CURATOR: @@ -631,7 +631,16 @@ def update_user_group( removed_users = db_session.scalars( select(User).where(User.id.in_(removed_user_ids)) # type: ignore ).unique() - _validate_curator_status__no_commit(db_session, list(removed_users)) + + # Filter out admin and global curator users before validating curator status + users_to_validate = [ + user + for user in removed_users + if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR] + ] + + if users_to_validate: + _validate_curator_status__no_commit(db_session, users_to_validate) # update "time_updated" to now db_user_group.time_last_modified_by_user = func.now() diff --git a/backend/ee/onyx/external_permissions/confluence/doc_sync.py b/backend/ee/onyx/external_permissions/confluence/doc_sync.py index 507a941b8..8ed076a3c 100644 --- a/backend/ee/onyx/external_permissions/confluence/doc_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/doc_sync.py @@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR from onyx.access.models import DocExternalAccess from 
onyx.access.models import ExternalAccess from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.confluence.onyx_confluence import ( + get_user_email_from_username__server, +) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence -from onyx.connectors.confluence.utils import get_user_email_from_username__server +from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.models import SlimDocument from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger +from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -342,7 +346,8 @@ def _fetch_all_page_restrictions( def confluence_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres @@ -354,7 +359,11 @@ def confluence_doc_sync( confluence_connector = ConfluenceConnector( **cc_pair.connector.connector_specific_config ) - confluence_connector.load_credentials(cc_pair.credential.credential_json) + + provider = OnyxDBCredentialsProvider( + get_current_tenant_id(), "confluence", cc_pair.credential_id + ) + confluence_connector.set_credentials_provider(provider) is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) diff --git a/backend/ee/onyx/external_permissions/confluence/group_sync.py b/backend/ee/onyx/external_permissions/confluence/group_sync.py index b11d38f63..b1113a5ab 100644 --- a/backend/ee/onyx/external_permissions/confluence/group_sync.py +++ b/backend/ee/onyx/external_permissions/confluence/group_sync.py @@ -1,9 +1,11 @@ from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME from onyx.background.error_logging import emit_background_error -from onyx.connectors.confluence.onyx_confluence import build_confluence_client +from onyx.connectors.confluence.onyx_confluence import ( + get_user_email_from_username__server, +) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence -from onyx.connectors.confluence.utils import get_user_email_from_username__server +from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.db.models import ConnectorCredentialPair from onyx.utils.logger import setup_logger @@ -61,13 +63,27 @@ def _build_group_member_email_map( def confluence_group_sync( + tenant_id: str, cc_pair: ConnectorCredentialPair, ) -> list[ExternalUserGroup]: - confluence_client = build_confluence_client( - credentials=cc_pair.credential.credential_json, - is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False), - wiki_base=cc_pair.connector.connector_specific_config["wiki_base"], - ) + provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id) + is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) + wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"] + url = wiki_base.rstrip("/") + + probe_kwargs = { + "max_backoff_retries": 6, + "max_backoff_seconds": 10, + } + + final_kwargs = { + "max_backoff_retries": 10, + "max_backoff_seconds": 60, + } + + confluence_client = OnyxConfluence(is_cloud, url, provider) + confluence_client._probe_connection(**probe_kwargs) + 
confluence_client._initialize_connection(**final_kwargs) group_member_email_map = _build_group_member_email_map( confluence_client=confluence_client, diff --git a/backend/ee/onyx/external_permissions/gmail/doc_sync.py b/backend/ee/onyx/external_permissions/gmail/doc_sync.py index a5563d73b..6f1bae674 100644 --- a/backend/ee/onyx/external_permissions/gmail/doc_sync.py +++ b/backend/ee/onyx/external_permissions/gmail/doc_sync.py @@ -32,7 +32,8 @@ def _get_slim_doc_generator( def gmail_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres diff --git a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py index 32f8993d0..8d3df7fa8 100644 --- a/backend/ee/onyx/external_permissions/google_drive/doc_sync.py +++ b/backend/ee/onyx/external_permissions/google_drive/doc_sync.py @@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc( def gdrive_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres diff --git a/backend/ee/onyx/external_permissions/google_drive/group_sync.py b/backend/ee/onyx/external_permissions/google_drive/group_sync.py index 7d1a27dbe..241aa4780 100644 --- a/backend/ee/onyx/external_permissions/google_drive/group_sync.py +++ b/backend/ee/onyx/external_permissions/google_drive/group_sync.py @@ -119,6 +119,7 @@ def _build_onyx_groups( def gdrive_group_sync( + tenant_id: str, cc_pair: ConnectorCredentialPair, ) -> list[ExternalUserGroup]: # Initialize connector and build credential/service objects diff --git a/backend/ee/onyx/external_permissions/slack/doc_sync.py b/backend/ee/onyx/external_permissions/slack/doc_sync.py index 9522c906d..0ae9b58cc 100644 --- a/backend/ee/onyx/external_permissions/slack/doc_sync.py +++ b/backend/ee/onyx/external_permissions/slack/doc_sync.py @@ -123,7 +123,8 @@ def _fetch_channel_permissions( def slack_doc_sync( - cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None + cc_pair: ConnectorCredentialPair, + callback: IndexingHeartbeatInterface | None, ) -> list[DocExternalAccess]: """ Adds the external permissions to the documents in postgres diff --git a/backend/ee/onyx/external_permissions/sync_params.py b/backend/ee/onyx/external_permissions/sync_params.py index 8be6dcb2c..9f8ed9681 100644 --- a/backend/ee/onyx/external_permissions/sync_params.py +++ b/backend/ee/onyx/external_permissions/sync_params.py @@ -28,6 +28,7 @@ DocSyncFuncType = Callable[ GroupSyncFuncType = Callable[ [ + str, ConnectorCredentialPair, ], list[ExternalUserGroup], diff --git a/backend/ee/onyx/main.py b/backend/ee/onyx/main.py index cf6b8191c..7d7278bb2 100644 --- a/backend/ee/onyx/main.py +++ b/backend/ee/onyx/main.py @@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import ( ) from ee.onyx.server.manage.standard_answer import router as standard_answer_router from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware -from ee.onyx.server.oauth import router as oauth_router +from ee.onyx.server.oauth.api import router as oauth_router from ee.onyx.server.query_and_chat.chat_backend import ( router as chat_router, ) @@ -152,4 +152,8 @@ def 
get_application() -> FastAPI: # environment variable. Used to automate deployment for multiple environments. seed_db() + # for debugging discovered routes + # for route in application.router.routes: + # print(f"Path: {route.path}, Methods: {route.methods}") + return application diff --git a/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py b/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py index 2da5ea5ca..3f7b08934 100644 --- a/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py +++ b/backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py @@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID from onyx.onyxbot.slack.handlers.utils import send_team_member_message from onyx.onyxbot.slack.models import SlackMessageInfo -from onyx.onyxbot.slack.utils import respond_in_thread +from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import update_emote_react from onyx.utils.logger import OnyxLoggingAdapter from onyx.utils.logger import setup_logger @@ -216,7 +216,7 @@ def _handle_standard_answers( all_blocks = restate_question_blocks + answer_blocks try: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=message_info.channel_to_respond, receiver_ids=receiver_ids, @@ -231,6 +231,7 @@ def _handle_standard_answers( client=client, channel=message_info.channel_to_respond, thread_ts=slack_thread_id, + receiver_ids=receiver_ids, ) return True diff --git a/backend/ee/onyx/server/oauth/api.py b/backend/ee/onyx/server/oauth/api.py new file mode 100644 index 000000000..f9eb4a751 --- /dev/null +++ b/backend/ee/onyx/server/oauth/api.py @@ -0,0 +1,91 @@ +import base64 +import uuid + +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse + +from ee.onyx.server.oauth.api_router import router +from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth +from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth +from ee.onyx.server.oauth.slack import SlackOAuth +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.constants import DocumentSource +from onyx.db.engine import get_current_tenant_id +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +@router.post("/prepare-authorization-request") +def prepare_authorization_request( + connector: DocumentSource, + redirect_on_success: str | None, + user: User = Depends(current_admin_user), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Used by the frontend to generate the url for the user's browser during auth request. 
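+
+    Flow: mint a random UUID, urlsafe-b64 encode it as the OAuth "state"
+    param, stash the session (admin email + redirect target) in Redis under
+    that UUID for 10 minutes, then hand back the provider authorization URL.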
+ + Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/ + """ + + # create random oauth state param for security and to retrieve user data later + oauth_uuid = uuid.uuid4() + oauth_uuid_str = str(oauth_uuid) + + # urlsafe b64 encode the uuid for the oauth url + oauth_state = ( + base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8") + ) + + session: str | None = None + if connector == DocumentSource.SLACK: + if not DEV_MODE: + oauth_url = SlackOAuth.generate_oauth_url(oauth_state) + else: + oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state) + + session = SlackOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + elif connector == DocumentSource.CONFLUENCE: + if not DEV_MODE: + oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state) + else: + oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state) + session = ConfluenceCloudOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + elif connector == DocumentSource.GOOGLE_DRIVE: + if not DEV_MODE: + oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state) + else: + oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state) + session = GoogleDriveOAuth.session_dump_json( + email=user.email, redirect_on_success=redirect_on_success + ) + else: + oauth_url = None + + if not oauth_url: + raise HTTPException( + status_code=404, + detail=f"The document source type {connector} does not have OAuth implemented", + ) + + if not session: + raise HTTPException( + status_code=500, + detail=f"The document source type {connector} failed to generate an OAuth session.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # store important session state to retrieve when the user is redirected back + # 10 min is the max we want an oauth flow to be valid + r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600) + + return JSONResponse(content={"url": oauth_url}) diff --git a/backend/ee/onyx/server/oauth/api_router.py b/backend/ee/onyx/server/oauth/api_router.py new file mode 100644 index 000000000..b99ec55a4 --- /dev/null +++ b/backend/ee/onyx/server/oauth/api_router.py @@ -0,0 +1,3 @@ +from fastapi import APIRouter + +router: APIRouter = APIRouter(prefix="/oauth") diff --git a/backend/ee/onyx/server/oauth/confluence_cloud.py b/backend/ee/onyx/server/oauth/confluence_cloud.py new file mode 100644 index 000000000..22fd23f98 --- /dev/null +++ b/backend/ee/onyx/server/oauth/confluence_cloud.py @@ -0,0 +1,361 @@ +import base64 +import uuid +from datetime import datetime +from datetime import timedelta +from datetime import timezone +from typing import Any +from typing import cast + +import requests +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from pydantic import ValidationError +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET +from ee.onyx.server.oauth.api_router import router +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL +from onyx.db.credentials import create_credential +from onyx.db.credentials import fetch_credential_by_id_for_user +from 
onyx.db.credentials import update_credential_json +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +class ConfluenceCloudOAuth: + # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/ + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + class TokenResponse(BaseModel): + access_token: str + expires_in: int + token_type: str + refresh_token: str + scope: str + + class AccessibleResources(BaseModel): + id: str + name: str + url: str + scopes: list[str] + avatarUrl: str + + CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID + CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET + TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL + + ACCESSIBLE_RESOURCE_URL = ( + "https://api.atlassian.com/oauth/token/accessible-resources" + ) + + # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/ + CONFLUENCE_OAUTH_SCOPE = ( + # classic scope + "read:confluence-space.summary%20" + "read:confluence-props%20" + "read:confluence-content.all%20" + "read:confluence-content.summary%20" + "read:confluence-content.permission%20" + "read:confluence-user%20" + "read:confluence-groups%20" + "readonly:content.attachment:confluence%20" + "search:confluence%20" + # granular scope + "read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api + "offline_access" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + # eventually for Confluence Data Center + # oauth_url = ( + # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}" + # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}" + # f"&redirect_uri={redirectme_uri}" + # ) + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + # https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code + + url = ( + "https://auth.atlassian.com/authorize" + f"?audience=api.atlassian.com" + f"&client_id={cls.CLIENT_ID}" + f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}" + f"&redirect_uri={redirect_uri}" + f"&state={state}" + "&response_type=code" + "&prompt=consent" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. 
+ """ + session = ConfluenceCloudOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json) + return session + + @classmethod + def generate_finalize_url(cls, credential_id: int) -> str: + return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}" + + +@router.post("/connector/confluence/callback") +def confluence_oauth_callback( + code: str, + state: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + """Handles the backend logic for the frontend page that the user is redirected to + after visiting the oauth authorization url.""" + + if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Confluence Cloud client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = ConfluenceCloudOAuth.parse_session(session_json) + + if not DEV_MODE: + redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI + else: + redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI + + # Exchange the authorization code for an access token + response = requests.post( + ConfluenceCloudOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": ConfluenceCloudOAuth.CLIENT_ID, + "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + token_response: ConfluenceCloudOAuth.TokenResponse | None = None + + try: + token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json( + response.text + ) + except Exception: + raise RuntimeError( + "Confluence Cloud OAuth failed during code/token exchange." 
+            )
+
+        now = datetime.now(timezone.utc)
+        expires_at = now + timedelta(seconds=token_response.expires_in)
+
+        credential_info = CredentialBase(
+            credential_json={
+                "confluence_access_token": token_response.access_token,
+                "confluence_refresh_token": token_response.refresh_token,
+                "created_at": now.isoformat(),
+                "expires_at": expires_at.isoformat(),
+                "expires_in": token_response.expires_in,
+                "scope": token_response.scope,
+            },
+            admin_public=True,
+            source=DocumentSource.CONFLUENCE,
+            name="Confluence Cloud OAuth",
+        )
+
+        credential = create_credential(credential_info, user, db_session)
+    except Exception as e:
+        return JSONResponse(
+            status_code=500,
+            content={
+                "success": False,
+                "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
+            },
+        )
+    finally:
+        r.delete(r_key)
+
+    # return the result
+    return JSONResponse(
+        content={
+            "success": True,
+            "message": "Confluence Cloud OAuth completed successfully.",
+            "finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
+            "redirect_on_success": session.redirect_on_success,
+        }
+    )
+
+
+@router.get("/connector/confluence/accessible-resources")
+def confluence_oauth_accessible_resources(
+    credential_id: int,
+    user: User = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
+) -> JSONResponse:
+    """Atlassian's API is weird and does not supply us with enough info to be in a
+    usable state after authorizing. All APIs require a cloud id. We have to list
+    the accessible resources/sites and let the user choose which site to use."""
+
+    credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
+    if not credential:
+        raise HTTPException(400, f"Credential {credential_id} not found.")
+
+    credential_dict = credential.credential_json
+    access_token = credential_dict["confluence_access_token"]
+
+    try:
+        # Fetch the resources/sites accessible to this access token
+        response = requests.get(
+            ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
+            headers={
+                "Authorization": f"Bearer {access_token}",
+                "Accept": "application/json",
+            },
+        )
+
+        response.raise_for_status()
+        accessible_resources_data = response.json()
+
+        # Validate the list of AccessibleResources
+        try:
+            accessible_resources = [
+                ConfluenceCloudOAuth.AccessibleResources(**resource)
+                for resource in accessible_resources_data
+            ]
+        except ValidationError as e:
+            raise RuntimeError(f"Failed to parse accessible resources: {e}")
+    except Exception as e:
+        return JSONResponse(
+            status_code=500,
+            content={
+                "success": False,
+                "message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
+            },
+        )
+
+    # return the result
+    return JSONResponse(
+        content={
+            "success": True,
+            "message": "Confluence Cloud get accessible resources completed successfully.",
+            "accessible_resources": [
+                resource.model_dump() for resource in accessible_resources
+            ],
+        }
+    )
+
+
+@router.post("/connector/confluence/finalize")
+def confluence_oauth_finalize(
+    credential_id: int,
+    cloud_id: str,
+    cloud_name: str,
+    cloud_url: str,
+    user: User = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
+) -> JSONResponse:
+    """Saves the info for the selected cloud site to the credential.
+    This is the final step in the Confluence OAuth flow: after the traditional
+    OAuth process, the user has to select a site to associate with the credentials.
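+    The selected site supplies the cloud id that every subsequent Confluence
+    API call needs, which is why it gets written onto the credential here.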
+ After this, the credential is usable.""" + + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) + if not credential: + raise HTTPException( + status_code=400, + detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.", + ) + + new_credential_json: dict[str, Any] = dict(credential.credential_json) + new_credential_json["cloud_id"] = cloud_id + new_credential_json["cloud_name"] = cloud_name + new_credential_json["wiki_base"] = cloud_url + + try: + update_credential_json(credential_id, new_credential_json, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}", + }, + ) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Confluence Cloud OAuth finalized successfully.", + "redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence", + } + ) diff --git a/backend/ee/onyx/server/oauth/google_drive.py b/backend/ee/onyx/server/oauth/google_drive.py new file mode 100644 index 000000000..68f224c76 --- /dev/null +++ b/backend/ee/onyx/server/oauth/google_drive.py @@ -0,0 +1,229 @@ +import base64 +import json +import uuid +from typing import Any +from typing import cast + +import requests +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET +from ee.onyx.server.oauth.api_router import router +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.connectors.google_utils.google_auth import get_google_oauth_creds +from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_AUTHENTICATION_METHOD, +) +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_DICT_TOKEN_KEY, +) +from onyx.connectors.google_utils.shared_constants import ( + DB_CREDENTIALS_PRIMARY_ADMIN_KEY, +) +from onyx.connectors.google_utils.shared_constants import ( + GoogleOAuthAuthenticationMethod, +) +from onyx.db.credentials import create_credential +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase + + +class GoogleDriveOAuth: + # https://developers.google.com/identity/protocols/oauth2 + # https://developers.google.com/identity/protocols/oauth2/web-server + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID + CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + + TOKEN_URL = "https://oauth2.googleapis.com/token" + + # SCOPE is per https://docs.danswer.dev/connectors/google-drive + # TODO: Merge with or use google_utils.GOOGLE_SCOPES + SCOPE = ( + "https://www.googleapis.com/auth/drive.readonly%20" + "https://www.googleapis.com/auth/drive.metadata.readonly%20" + "https://www.googleapis.com/auth/admin.directory.user.readonly%20" + 
"https://www.googleapis.com/auth/admin.directory.group.readonly" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + # without prompt=consent, a refresh token is only issued the first time the user approves + url = ( + f"https://accounts.google.com/o/oauth2/v2/auth" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + "&response_type=code" + f"&scope={cls.SCOPE}" + "&access_type=offline" + f"&state={state}" + "&prompt=consent" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = GoogleDriveOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json) + return session + + +@router.post("/connector/google-drive/callback") +def handle_google_drive_oauth_callback( + code: str, + state: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Google Drive client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = GoogleDriveOAuth.parse_session(session_json) + + if not DEV_MODE: + redirect_uri = GoogleDriveOAuth.REDIRECT_URI + else: + redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI + + # Exchange the authorization code for an access token + response = requests.post( + GoogleDriveOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": GoogleDriveOAuth.CLIENT_ID, + "client_secret": GoogleDriveOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + response.raise_for_status() + + authorization_response: dict[str, Any] = response.json() + + # the connector wants us to store the json in its authorized_user_info format + # returned from OAuthCredentials.get_authorized_user_info(). 
+ # So refresh immediately via get_google_oauth_creds with the params filled in + # from fields in authorization_response to get the json we need + authorized_user_info = {} + authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID + authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET + authorized_user_info["refresh_token"] = authorization_response["refresh_token"] + + token_json_str = json.dumps(authorized_user_info) + oauth_creds = get_google_oauth_creds( + token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE + ) + if not oauth_creds: + raise RuntimeError("get_google_oauth_creds returned None.") + + # save off the credentials + oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds) + + credential_dict: dict[str, str] = {} + credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str + credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email + credential_dict[ + DB_CREDENTIALS_AUTHENTICATION_METHOD + ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value + + credential_info = CredentialBase( + credential_json=credential_dict, + admin_public=True, + source=DocumentSource.GOOGLE_DRIVE, + name="OAuth (interactive)", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Google Drive OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Google Drive OAuth completed successfully.", + "finalize_url": None, + "redirect_on_success": session.redirect_on_success, + } + ) diff --git a/backend/ee/onyx/server/oauth/slack.py b/backend/ee/onyx/server/oauth/slack.py new file mode 100644 index 000000000..e8c5c3063 --- /dev/null +++ b/backend/ee/onyx/server/oauth/slack.py @@ -0,0 +1,197 @@ +import base64 +import uuid +from typing import cast + +import requests +from fastapi import Depends +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET +from ee.onyx.server.oauth.api_router import router +from onyx.auth.users import current_admin_user +from onyx.configs.app_configs import DEV_MODE +from onyx.configs.app_configs import WEB_DOMAIN +from onyx.configs.constants import DocumentSource +from onyx.db.credentials import create_credential +from onyx.db.engine import get_current_tenant_id +from onyx.db.engine import get_session +from onyx.db.models import User +from onyx.redis.redis_pool import get_redis_client +from onyx.server.documents.models import CredentialBase + + +class SlackOAuth: + # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth + # Example: https://api.slack.com/authentication/oauth-v2#exchanging + + class OAuthSession(BaseModel): + """Stored in redis to be looked up on callback""" + + email: str + redirect_on_success: str | None # Where to send the user if OAuth flow succeeds + + CLIENT_ID = OAUTH_SLACK_CLIENT_ID + CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET + + TOKEN_URL = "https://slack.com/api/oauth.v2.access" + + # SCOPE is per https://docs.danswer.dev/connectors/slack + BOT_SCOPE = ( + "channels:history," + "channels:read," + "groups:history," + "groups:read," + "channels:join," + "im:history," + "users:read," + 
"users:read.email," + "usergroups:read" + ) + + REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback" + DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" + + @classmethod + def generate_oauth_url(cls, state: str) -> str: + return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) + + @classmethod + def generate_dev_oauth_url(cls, state: str) -> str: + """dev mode workaround for localhost testing + - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https + """ + + return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) + + @classmethod + def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: + url = ( + f"https://slack.com/oauth/v2/authorize" + f"?client_id={cls.CLIENT_ID}" + f"&redirect_uri={redirect_uri}" + f"&scope={cls.BOT_SCOPE}" + f"&state={state}" + ) + return url + + @classmethod + def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: + """Temporary state to store in redis. to be looked up on auth response. + Returns a json string. + """ + session = SlackOAuth.OAuthSession( + email=email, redirect_on_success=redirect_on_success + ) + return session.model_dump_json() + + @classmethod + def parse_session(cls, session_json: str) -> OAuthSession: + session = SlackOAuth.OAuthSession.model_validate_json(session_json) + return session + + +@router.post("/connector/slack/callback") +def handle_slack_oauth_callback( + code: str, + state: str, + user: User = Depends(current_admin_user), + db_session: Session = Depends(get_session), + tenant_id: str | None = Depends(get_current_tenant_id), +) -> JSONResponse: + if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET: + raise HTTPException( + status_code=500, + detail="Slack client ID or client secret is not configured.", + ) + + r = get_redis_client(tenant_id=tenant_id) + + # recover the state + padded_state = state + "=" * ( + -len(state) % 4 + ) # Add padding back (Base64 decoding requires padding) + uuid_bytes = base64.urlsafe_b64decode( + padded_state + ) # Decode the Base64 string back to bytes + + # Convert bytes back to a UUID + oauth_uuid = uuid.UUID(bytes=uuid_bytes) + oauth_uuid_str = str(oauth_uuid) + + r_key = f"da_oauth:{oauth_uuid_str}" + + session_json_bytes = cast(bytes, r.get(r_key)) + if not session_json_bytes: + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}", + ) + + session_json = session_json_bytes.decode("utf-8") + try: + session = SlackOAuth.parse_session(session_json) + + if not DEV_MODE: + redirect_uri = SlackOAuth.REDIRECT_URI + else: + redirect_uri = SlackOAuth.DEV_REDIRECT_URI + + # Exchange the authorization code for an access token + response = requests.post( + SlackOAuth.TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "client_id": SlackOAuth.CLIENT_ID, + "client_secret": SlackOAuth.CLIENT_SECRET, + "code": code, + "redirect_uri": redirect_uri, + }, + ) + + response_data = response.json() + + if not response_data.get("ok"): + raise HTTPException( + status_code=400, + detail=f"Slack OAuth failed: {response_data.get('error')}", + ) + + # Extract token and team information + access_token: str = response_data.get("access_token") + team_id: str = response_data.get("team", {}).get("id") + authed_user_id: str = response_data.get("authed_user", {}).get("id") + + credential_info = CredentialBase( + credential_json={"slack_bot_token": access_token}, + admin_public=True, + source=DocumentSource.SLACK, + name="Slack 
OAuth", + ) + + create_credential(credential_info, user, db_session) + except Exception as e: + return JSONResponse( + status_code=500, + content={ + "success": False, + "message": f"An error occurred during Slack OAuth: {str(e)}", + }, + ) + finally: + r.delete(r_key) + + # return the result + return JSONResponse( + content={ + "success": True, + "message": "Slack OAuth completed successfully.", + "finalize_url": None, + "redirect_on_success": session.redirect_on_success, + "team_id": team_id, + "authed_user_id": authed_user_id, + } + ) diff --git a/backend/ee/onyx/server/query_history/api.py b/backend/ee/onyx/server/query_history/api.py index 6921c280d..11595ae7a 100644 --- a/backend/ee/onyx/server/query_history/api.py +++ b/backend/ee/onyx/server/query_history/api.py @@ -138,6 +138,7 @@ def get_user_chat_sessions( name=chat.description, persona_id=chat.persona_id, time_created=chat.time_created.isoformat(), + time_updated=chat.time_updated.isoformat(), shared_status=chat.shared_status, folder_id=chat.folder_id, current_alternate_model=chat.current_alternate_model, diff --git a/backend/ee/onyx/server/tenants/billing.py b/backend/ee/onyx/server/tenants/billing.py index 98de75a9a..7c5ae8534 100644 --- a/backend/ee/onyx/server/tenants/billing.py +++ b/backend/ee/onyx/server/tenants/billing.py @@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY from ee.onyx.server.tenants.access import generate_data_plane_token from ee.onyx.server.tenants.models import BillingInformation +from ee.onyx.server.tenants.models import SubscriptionStatusResponse from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.utils.logger import setup_logger @@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict: return response.json() -def fetch_billing_information(tenant_id: str) -> BillingInformation: +def fetch_billing_information( + tenant_id: str, +) -> BillingInformation | SubscriptionStatusResponse: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { @@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation: params = {"tenant_id": tenant_id} response = requests.get(url, headers=headers, params=params) response.raise_for_status() - billing_info = BillingInformation(**response.json()) - return billing_info + + response_data = response.json() + + # Check if the response indicates no subscription + if ( + isinstance(response_data, dict) + and "subscribed" in response_data + and not response_data["subscribed"] + ): + return SubscriptionStatusResponse(**response_data) + + # Otherwise, parse as BillingInformation + return BillingInformation(**response_data) def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription: diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 8521cd001..9215042f1 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -78,7 +78,7 @@ class CloudEmbedding: self._closed = False async def _embed_openai( - self, texts: list[str], model: str | None + self, texts: list[str], model: str | None, reduced_dimension: int | None ) -> list[Embedding]: if not model: model = DEFAULT_OPENAI_MODEL @@ -91,7 +91,11 @@ class CloudEmbedding: final_embeddings: list[Embedding] = [] try: for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN): - response = await client.embeddings.create(input=text_batch, model=model) + response = await 
client.embeddings.create( + input=text_batch, + model=model, + dimensions=reduced_dimension or openai.NOT_GIVEN, + ) final_embeddings.extend( [embedding.embedding for embedding in response.data] ) @@ -223,9 +227,10 @@ class CloudEmbedding: text_type: EmbedTextType, model_name: str | None = None, deployment_name: str | None = None, + reduced_dimension: int | None = None, ) -> list[Embedding]: if self.provider == EmbeddingProvider.OPENAI: - return await self._embed_openai(texts, model_name) + return await self._embed_openai(texts, model_name, reduced_dimension) elif self.provider == EmbeddingProvider.AZURE: return await self._embed_azure(texts, f"azure/{deployment_name}") elif self.provider == EmbeddingProvider.LITELLM: @@ -326,6 +331,7 @@ async def embed_text( prefix: str | None, api_url: str | None, api_version: str | None, + reduced_dimension: int | None, gpu_type: str = "UNKNOWN", ) -> list[Embedding]: if not all(texts): @@ -369,6 +375,7 @@ async def embed_text( model_name=model_name, deployment_name=deployment_name, text_type=text_type, + reduced_dimension=reduced_dimension, ) if any(embedding is None for embedding in embeddings): @@ -508,6 +515,7 @@ async def process_embed_request( text_type=embed_request.text_type, api_url=embed_request.api_url, api_version=embed_request.api_version, + reduced_dimension=embed_request.reduced_dimension, prefix=prefix, gpu_type=gpu_type, ) diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 07eda2b9b..06ca53ff2 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): "refresh_token": refresh_token, } - user: User + user: User | None = None try: # Attempt to get user by OAuth account @@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): except exceptions.UserNotExists: try: # Attempt to get user by email - user = cast(User, await self.user_db.get_by_email(account_email)) + user = await self.user_db.get_by_email(account_email) if not associate_by_email: raise exceptions.UserAlreadyExists() - user = await self.user_db.add_oauth_account( - user, oauth_account_dict - ) + # Make sure user is not None before adding OAuth account + if user is not None: + user = await self.user_db.add_oauth_account( + user, oauth_account_dict + ) + else: + # This shouldn't happen since get_by_email would raise UserNotExists + # but adding as a safeguard + raise exceptions.UserNotExists() - # If user not found by OAuth account or email, create a new user except exceptions.UserNotExists: password = self.password_helper.generate() user_dict = { @@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): user = await self.user_db.create(user_dict) - # Explicitly set the Postgres schema for this session to ensure - # OAuth account creation happens in the correct tenant schema - - # Add OAuth account - await self.user_db.add_oauth_account(user, oauth_account_dict) - await self.on_after_register(user, request) + # Add OAuth account only if user creation was successful + if user is not None: + await self.user_db.add_oauth_account(user, oauth_account_dict) + await self.on_after_register(user, request) + else: + raise HTTPException( + status_code=500, detail="Failed to create user account" + ) else: - for existing_oauth_account in user.oauth_accounts: - if ( - existing_oauth_account.account_id == account_id - and existing_oauth_account.oauth_name == oauth_name - ): - user = await 
self.user_db.update_oauth_account( - user, - # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol - # but the type checker doesn't know that :( - existing_oauth_account, # type: ignore - oauth_account_dict, - ) + # User exists, update OAuth account if needed + if user is not None: # Add explicit check + for existing_oauth_account in user.oauth_accounts: + if ( + existing_oauth_account.account_id == account_id + and existing_oauth_account.oauth_name == oauth_name + ): + user = await self.user_db.update_oauth_account( + user, + # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol + # but the type checker doesn't know that :( + existing_oauth_account, # type: ignore + oauth_account_dict, + ) + + # Ensure user is not None before proceeding + if user is None: + raise HTTPException( + status_code=500, detail="Failed to authenticate or create user" + ) # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to # re-authenticate that frequently, so by default this is disabled diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 0a3c395d9..bb445affc 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task( ) external_user_groups: list[ExternalUserGroup] = [] try: - external_user_groups = ext_group_sync_func(cc_pair) + external_user_groups = ext_group_sync_func(tenant_id, cc_pair) except ConnectorValidationError as e: msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}" update_connector_credential_pair( diff --git a/backend/onyx/background/celery/tasks/indexing/tasks.py b/backend/onyx/background/celery/tasks/indexing/tasks.py index 1c00e0590..16d65f831 100644 --- a/backend/onyx/background/celery/tasks/indexing/tasks.py +++ b/backend/onyx/background/celery/tasks/indexing/tasks.py @@ -23,9 +23,9 @@ from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_utils import httpx_init_vespa_pool -from onyx.background.celery.tasks.indexing.utils import _should_index from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids from onyx.background.celery.tasks.indexing.utils import IndexingCallback +from onyx.background.celery.tasks.indexing.utils import should_index from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint @@ -61,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_failed from onyx.db.search_settings import get_active_search_settings_list from onyx.db.search_settings import get_current_search_settings from onyx.db.session import get_session_with_current_tenant -from onyx.db.swap_index import check_index_swap +from onyx.db.swap_index import check_and_perform_index_swap from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder from onyx.redis.redis_connector import RedisConnector @@ -406,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: # check for search settings swap with get_session_with_current_tenant() as db_session: - 
old_search_settings = check_index_swap(db_session=db_session) + old_search_settings = check_and_perform_index_swap(db_session=db_session) current_search_settings = get_current_search_settings(db_session) # So that the first time users aren't surprised by really slow speed of first # batch of documents indexed @@ -439,6 +439,15 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: with get_session_with_current_tenant() as db_session: search_settings_list = get_active_search_settings_list(db_session) for search_settings_instance in search_settings_list: + # skip non-live search settings that don't have background reindex enabled + # those should just auto-change to live shortly after creation without + # requiring any indexing till that point + if ( + not search_settings_instance.status.is_current() + and not search_settings_instance.background_reindex_enabled + ): + continue + redis_connector_index = redis_connector.new_index( search_settings_instance.id ) @@ -456,23 +465,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: cc_pair.id, search_settings_instance.id, db_session ) - search_settings_primary = False - if search_settings_instance.id == search_settings_list[0].id: - search_settings_primary = True - - if not _should_index( + if not should_index( cc_pair=cc_pair, last_index=last_attempt, search_settings_instance=search_settings_instance, - search_settings_primary=search_settings_primary, secondary_index_building=len(search_settings_list) > 1, db_session=db_session, ): continue reindex = False - if search_settings_instance.id == search_settings_list[0].id: - # the indexing trigger is only checked and cleared with the primary search settings + if search_settings_instance.status.is_current(): + # the indexing trigger is only checked and cleared with the current search settings if cc_pair.indexing_trigger is not None: if cc_pair.indexing_trigger == IndexingMode.REINDEX: reindex = True diff --git a/backend/onyx/background/celery/tasks/indexing/utils.py b/backend/onyx/background/celery/tasks/indexing/utils.py index 1a39f0912..39a2268a9 100644 --- a/backend/onyx/background/celery/tasks/indexing/utils.py +++ b/backend/onyx/background/celery/tasks/indexing/utils.py @@ -346,11 +346,10 @@ def validate_indexing_fences( return -def _should_index( +def should_index( cc_pair: ConnectorCredentialPair, last_index: IndexAttempt | None, search_settings_instance: SearchSettings, - search_settings_primary: bool, secondary_index_building: bool, db_session: Session, ) -> bool: @@ -415,9 +414,9 @@ def _should_index( ): return False - if search_settings_primary: + if search_settings_instance.status.is_current(): if cc_pair.indexing_trigger is not None: - # if a manual indexing trigger is on the cc pair, honor it for primary search settings + # if a manual indexing trigger is on the cc pair, honor it for live search settings return True # if no attempt has ever occurred, we should index regardless of refresh_freq diff --git a/backend/onyx/background/error_logging.py b/backend/onyx/background/error_logging.py index c5abef2fb..1a1bdcebd 100644 --- a/backend/onyx/background/error_logging.py +++ b/backend/onyx/background/error_logging.py @@ -11,10 +11,27 @@ def emit_background_error( """Currently just saves a row in the background_errors table. 
In the future, could create notifications based on the severity.""" - with get_session_with_current_tenant() as db_session: - try: + error_message = "" + + # try to write to the db, but handle IntegrityError specifically + try: + with get_session_with_current_tenant() as db_session: create_background_error(db_session, message, cc_pair_id) - except IntegrityError as e: - # Log an error if the cc_pair_id was deleted or any other exception occurs - error_message = f"Failed to create background error: {str(e)}. Original message: {message}" + except IntegrityError as e: + # Log an error if the cc_pair_id was deleted or any other exception occurs + error_message = ( + f"Failed to create background error: {str(e)}. Original message: {message}" + ) + except Exception: + pass + + if not error_message: + return + + # if we get here from an IntegrityError, try to write the error message to the db + # we need a new session because the first session is now invalid + try: + with get_session_with_current_tenant() as db_session: create_background_error(db_session, error_message, None) + except Exception: + pass diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index ddca17b2f..53502c3c5 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -93,10 +93,11 @@ def _get_connector_runner( runnable_connector.validate_connector_settings() except Exception as e: - logger.exception(f"Unable to instantiate connector due to {e}") - + logger.exception("Unable to instantiate connector.") # since we failed to even instantiate the connector, we pause the CCPair since - # it will never succeed. Sometimes there are cases where the connector will + # it will never succeed + + # Sometimes there are cases where the connector will # intermittently fail to initialize in which case we should pass in # leave_connector_active=True to allow it to continue. 
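The emit_background_error flow above relies on a subtlety worth spelling out: once a session has thrown IntegrityError it is no longer usable, so the fallback write must happen on a brand-new session. A minimal, self-contained sketch of that two-session pattern, using sqlite3 and hypothetical table/helper names rather than Onyx's actual session helpers:

import sqlite3

def open_session() -> sqlite3.Connection:
    # stands in for get_session_with_current_tenant()
    conn = sqlite3.connect("errors.db")
    conn.execute("PRAGMA foreign_keys = ON")
    conn.execute("CREATE TABLE IF NOT EXISTS cc_pair (id INTEGER PRIMARY KEY)")
    conn.execute(
        "CREATE TABLE IF NOT EXISTS background_errors ("
        "id INTEGER PRIMARY KEY, message TEXT NOT NULL, "
        "cc_pair_id INTEGER REFERENCES cc_pair(id))"
    )
    return conn

def emit_error(message: str, cc_pair_id: int | None) -> None:
    fallback = ""
    try:
        with open_session() as conn:
            conn.execute(
                "INSERT INTO background_errors (message, cc_pair_id) VALUES (?, ?)",
                (message, cc_pair_id),
            )
    except sqlite3.IntegrityError as e:
        # e.g. the cc_pair row was deleted out from under us
        fallback = f"Failed to create background error: {e}. Original message: {message}"
    except Exception:
        return

    if not fallback:
        return

    # second attempt on a fresh session; the first one is no longer usable
    try:
        with open_session() as conn:
            conn.execute(
                "INSERT INTO background_errors (message, cc_pair_id) VALUES (?, NULL)",
                (fallback,),
            )
    except Exception:
        pass

emit_error("indexing failed", cc_pair_id=12345)  # missing cc_pair -> fallback row written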
# For example, if there is nightly maintenance on a Confluence Server instance, diff --git a/backend/onyx/connectors/confluence/connector.py b/backend/onyx/connectors/confluence/connector.py index 43d00c42b..29e279014 100644 --- a/backend/onyx/connectors/confluence/connector.py +++ b/backend/onyx/connectors/confluence/connector.py @@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource -from onyx.connectors.confluence.onyx_confluence import build_confluence_client +from onyx.connectors.confluence.onyx_confluence import attachment_to_content +from onyx.connectors.confluence.onyx_confluence import ( + extract_text_from_confluence_html, +) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence -from onyx.connectors.confluence.utils import attachment_to_content from onyx.connectors.confluence.utils import build_confluence_document_id from onyx.connectors.confluence.utils import datetime_from_string -from onyx.connectors.confluence.utils import extract_text_from_confluence_html from onyx.connectors.confluence.utils import validate_attachment_filetype from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedError +from onyx.connectors.interfaces import CredentialsConnector +from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import LoadConnector @@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join( ) -class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): +class ConfluenceConnector( + LoadConnector, PollConnector, SlimConnector, CredentialsConnector +): def __init__( self, wiki_base: str, @@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): ) -> None: self.batch_size = batch_size self.continue_on_failure = continue_on_failure - self._confluence_client: OnyxConfluence | None = None self.is_cloud = is_cloud # Remove trailing slash from wiki_base if present @@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): self.cql_label_filter = f" and label not in ({comma_separated_labels})" self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset)) + self.credentials_provider: CredentialsProviderInterface | None = None + + self.probe_kwargs = { + "max_backoff_retries": 6, + "max_backoff_seconds": 10, + } + + self.final_kwargs = { + "max_backoff_retries": 10, + "max_backoff_seconds": 60, + } + + self._confluence_client: OnyxConfluence | None = None @property def confluence_client(self) -> OnyxConfluence: @@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): raise ConnectorMissingCredentialError("Confluence") return self._confluence_client - def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: - # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py - # for a list of other hidden constructor args - self._confluence_client = build_confluence_client( - credentials=credentials, - is_cloud=self.is_cloud, - 
wiki_base=self.wiki_base, + def set_credentials_provider( + self, credentials_provider: CredentialsProviderInterface + ) -> None: + self.credentials_provider = credentials_provider + + # raises exception if there's a problem + confluence_client = OnyxConfluence( + self.is_cloud, self.wiki_base, credentials_provider ) - return None + confluence_client._probe_connection(**self.probe_kwargs) + confluence_client._initialize_connection(**self.final_kwargs) + + self._confluence_client = confluence_client + + def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: + raise NotImplementedError("Use set_credentials_provider with this connector.") def _construct_page_query( self, @@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): return comment_string def _convert_object_to_document( - self, confluence_object: dict[str, Any] + self, + confluence_object: dict[str, Any], + parent_content_id: str | None = None, ) -> Document | None: """ Takes in a confluence object, extracts all metadata, and converts it into a document. If its a page, it extracts the text, adds the comments for the document text. If its an attachment, it just downloads the attachment and converts that into a document. + + parent_content_id: if the object is an attachment, specifies the content id that + the attachment is attached to """ # The url and the id are the same object_url = build_confluence_document_id( @@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): object_text += self._get_comment_string_for_page_id(confluence_object["id"]) elif confluence_object["type"] == "attachment": object_text = attachment_to_content( - confluence_client=self.confluence_client, attachment=confluence_object + confluence_client=self.confluence_client, + attachment=confluence_object, + parent_content_id=parent_content_id, ) if object_text is None: @@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): cql=attachment_query, expand=",".join(_ATTACHMENT_EXPANSION_FIELDS), ): - doc = self._convert_object_to_document(attachment) + doc = self._convert_object_to_document(attachment, confluence_page_id) if doc is not None: doc_batch.append(doc) if len(doc_batch) >= self.batch_size: diff --git a/backend/onyx/connectors/confluence/onyx_confluence.py b/backend/onyx/connectors/confluence/onyx_confluence.py index df28900bc..147ed82c6 100644 --- a/backend/onyx/connectors/confluence/onyx_confluence.py +++ b/backend/onyx/connectors/confluence/onyx_confluence.py @@ -1,19 +1,37 @@ -import math +import io +import json import time from collections.abc import Callable from collections.abc import Iterator +from datetime import datetime +from datetime import timedelta +from datetime import timezone from typing import Any from typing import cast from typing import TypeVar from urllib.parse import quote +import bs4 from atlassian import Confluence # type:ignore from pydantic import BaseModel +from redis import Redis from requests import HTTPError +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID +from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET +from onyx.configs.app_configs import ( + CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD, +) +from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD +from onyx.connectors.confluence.utils import _handle_http_error +from onyx.connectors.confluence.utils import confluence_refresh_tokens from 
onyx.connectors.confluence.utils import get_start_param_from_url from onyx.connectors.confluence.utils import update_param_in_path -from onyx.connectors.exceptions import ConnectorValidationError +from onyx.connectors.confluence.utils import validate_attachment_filetype +from onyx.connectors.interfaces import CredentialsProviderInterface +from onyx.file_processing.extract_file_text import extract_file_text +from onyx.file_processing.html_utils import format_document_soup +from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() @@ -22,12 +40,14 @@ logger = setup_logger() F = TypeVar("F", bound=Callable[..., Any]) -RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() - # https://jira.atlassian.com/browse/CONFCLOUD-76433 _PROBLEMATIC_EXPANSIONS = "body.storage.value" _REPLACEMENT_EXPANSIONS = "body.view.value" +_USER_NOT_FOUND = "Unknown Confluence User" +_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} +_USER_EMAIL_CACHE: dict[str, str | None] = {} + class ConfluenceRateLimitError(Exception): pass @@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel): type: str -def _handle_http_error(e: HTTPError, attempt: int) -> int: - MIN_DELAY = 2 - MAX_DELAY = 60 - STARTING_DELAY = 5 - BACKOFF = 2 - - # Check if the response or headers are None to avoid potential AttributeError - if e.response is None or e.response.headers is None: - logger.warning("HTTPError with `None` as response or as headers") - raise e - - if ( - e.response.status_code != 429 - and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower() - ): - raise e - - retry_after = None - - retry_after_header = e.response.headers.get("Retry-After") - if retry_after_header is not None: - try: - retry_after = int(retry_after_header) - if retry_after > MAX_DELAY: - logger.warning( - f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..." - ) - retry_after = MAX_DELAY - if retry_after < MIN_DELAY: - retry_after = MIN_DELAY - except ValueError: - pass - - if retry_after is not None: - logger.warning( - f"Rate limiting with retry header. Retrying after {retry_after} seconds..." - ) - delay = retry_after - else: - logger.warning( - "Rate limiting without retry header. Retrying with exponential backoff..." - ) - delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY) - - delay_until = math.ceil(time.monotonic() + delay) - return delay_until - - -# https://developer.atlassian.com/cloud/confluence/rate-limiting/ -# this uses the native rate limiting option provided by the -# confluence client and otherwise applies a simpler set of error handling -def handle_confluence_rate_limit(confluence_call: F) -> F: - def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: - MAX_RETRIES = 5 - - TIMEOUT = 600 - timeout_at = time.monotonic() + TIMEOUT - - for attempt in range(MAX_RETRIES): - if time.monotonic() > timeout_at: - raise TimeoutError( - f"Confluence call attempts took longer than {TIMEOUT} seconds." - ) - - try: - # we're relying more on the client to rate limit itself - # and applying our own retries in a more specific set of circumstances - return confluence_call(*args, **kwargs) - except HTTPError as e: - delay_until = _handle_http_error(e, attempt) - logger.warning( - f"HTTPError in confluence call. " - f"Retrying in {delay_until} seconds..." - ) - while time.monotonic() < delay_until: - # in the future, check a signal here to exit - time.sleep(1) - except AttributeError as e: - # Some error within the Confluence library, unclear why it fails. 
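The retry policy encoded in this helper (which, as shown further down, moves to confluence/utils.py rather than disappearing) is: honor a 429's Retry-After header when present, clamped between a minimum and maximum delay, and otherwise fall back to capped exponential backoff. A condensed sketch of just the delay computation, using the same constants as the real helper:

import math
import time

MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2

def delay_deadline(retry_after_header: str | None, attempt: int) -> int:
    # returns an absolute time.monotonic() deadline so the caller can sleep
    # in short, interruptible steps rather than one long sleep
    delay: int | None = None
    if retry_after_header is not None:
        try:
            # clamp the server-provided hint into [MIN_DELAY, MAX_DELAY]
            delay = max(MIN_DELAY, min(int(retry_after_header), MAX_DELAY))
        except ValueError:
            pass  # unparseable header -> fall through to backoff
    if delay is None:
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    return math.ceil(time.monotonic() + delay)

print(delay_deadline(None, attempt=3))  # 5 * 2**3 = 40 seconds from now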
- # Users reported it to be intermittent, so just retry - if attempt == MAX_RETRIES - 1: - raise e - - logger.exception( - "Confluence Client raised an AttributeError. Retrying..." - ) - time.sleep(5) - - return cast(F, wrapped_call) - - _DEFAULT_PAGINATION_LIMIT = 1000 _MINIMUM_PAGINATION_LIMIT = 50 -class OnyxConfluence(Confluence): +class OnyxConfluence: """ - This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method. + This is a custom Confluence class that: + + A. overrides the default Confluence class to add a custom CQL method. + B. is necessary because the default Confluence class does not properly support cql expansions. + All methods are automatically wrapped with handle_confluence_rate_limit. """ - def __init__(self, url: str, *args: Any, **kwargs: Any) -> None: - super(OnyxConfluence, self).__init__(url, *args, **kwargs) - self._wrap_methods() + CREDENTIAL_PREFIX = "connector:confluence:credential" + CREDENTIAL_TTL = 300 # 5 min - def _wrap_methods(self) -> None: + def __init__( + self, + is_cloud: bool, + url: str, + credentials_provider: CredentialsProviderInterface, + ) -> None: + self._is_cloud = is_cloud + self._url = url.rstrip("/") + self._credentials_provider = credentials_provider + + self.redis_client: Redis | None = None + self.static_credentials: dict[str, Any] | None = None + if self._credentials_provider.is_dynamic(): + self.redis_client = get_redis_client( + tenant_id=credentials_provider.get_tenant_id() + ) + else: + self.static_credentials = self._credentials_provider.get_credentials() + + self._confluence = Confluence(url) + self.credential_key: str = ( + self.CREDENTIAL_PREFIX + + f":credential_{self._credentials_provider.get_provider_key()}" + ) + + self._kwargs: Any = None + + self.shared_base_kwargs = { + "api_version": "cloud" if is_cloud else "latest", + "backoff_and_retry": True, + "cloud": is_cloud, + } + + def _renew_credentials(self) -> tuple[dict[str, Any], bool]: + """Returns a tuple: + 1. The up-to-date credentials + 2. True if the credentials were updated + + This method is intended to be used within a distributed lock. + Lock, call this, update credentials if the tokens were refreshed, then release. """ - For each attribute that is callable (i.e., a method) and doesn't start with an underscore, - wrap it with handle_confluence_rate_limit. - """ - for attr_name in dir(self): - if callable(getattr(self, attr_name)) and not attr_name.startswith("_"): - setattr( - self, - attr_name, - handle_confluence_rate_limit(getattr(self, attr_name)), """ + # static credentials are preloaded, so no locking/redis required + if self.static_credentials: + return self.static_credentials, False + + if not self.redis_client: + raise RuntimeError("self.redis_client is None") + + # dynamic credentials need locking + # check redis first, then fall back to the DB + credential_raw = self.redis_client.get(self.credential_key) + if credential_raw is not None: + credential_bytes = cast(bytes, credential_raw) + credential_str = credential_bytes.decode("utf-8") + credential_json: dict[str, Any] = json.loads(credential_str) + else: + credential_json = self._credentials_provider.get_credentials() + + if "confluence_refresh_token" not in credential_json: + # static credentials ... cache them permanently and return + self.static_credentials = credential_json + return credential_json, False + + # check if we should refresh tokens.
we're deciding to refresh halfway + # to expiration + now = datetime.now(timezone.utc) + created_at = datetime.fromisoformat(credential_json["created_at"]) + expires_in: int = credential_json["expires_in"] + renew_at = created_at + timedelta(seconds=expires_in // 2) + if now <= renew_at: + # cached/current credentials are reasonably up to date + return credential_json, False + + # we need to refresh + logger.info("Renewing Confluence Cloud credentials...") + new_credentials = confluence_refresh_tokens( + OAUTH_CONFLUENCE_CLOUD_CLIENT_ID, + OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET, + credential_json["cloud_id"], + credential_json["confluence_refresh_token"], + ) + + # store the new credentials to redis and to the db through the provider + # redis: we use a 5 min TTL because we are given a 10 minute grace period + # when keys are rotated. it's easier to expire the cached credentials + # reasonably frequently rather than trying to handle strong synchronization + # between the db and redis everywhere the credentials might be updated + new_credential_str = json.dumps(new_credentials) + self.redis_client.set( + self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL + ) + self._credentials_provider.set_credentials(new_credentials) + + return new_credentials, True + + @staticmethod + def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]: + oauth2_dict: dict[str, Any] = {} + if "confluence_refresh_token" in credentials: + oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID + oauth2_dict["token"] = {} + oauth2_dict["token"]["access_token"] = credentials[ + "confluence_access_token" + ] + return oauth2_dict + + def _probe_connection( + self, + **kwargs: Any, + ) -> None: + merged_kwargs = {**self.shared_base_kwargs, **kwargs} + + with self._credentials_provider: + credentials, _ = self._renew_credentials() + + # probe connection with direct client, no retries + if "confluence_refresh_token" in credentials: + logger.info("Probing Confluence with OAuth Access Token.") + + oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict( + credentials ) + url = ( + f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" + ) + confluence_client_with_minimal_retries = Confluence( + url=url, oauth2=oauth2_dict, **merged_kwargs + ) + else: + logger.info("Probing Confluence with Personal Access Token.") + url = self._url + if self._is_cloud: + confluence_client_with_minimal_retries = Confluence( + url=url, + username=credentials["confluence_username"], + password=credentials["confluence_access_token"], + **merged_kwargs, + ) + else: + confluence_client_with_minimal_retries = Confluence( + url=url, + token=credentials["confluence_access_token"], + **merged_kwargs, + ) + + spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1) + + # uncomment the following for testing + # the following is an attempt to retrieve the user's timezone + # Unfortunately, all data is returned in UTC regardless of the user's time zone + # even though CQL parses incoming times based on the user's time zone + # space_key = spaces["results"][0]["key"] + # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space") + + if not spaces: + raise RuntimeError( + f"No spaces found at {url}! " + "Check your credentials and wiki_base and make sure " + "is_cloud is set correctly."
+ ) + + logger.info("Confluence probe succeeded.") + + def _initialize_connection( + self, + **kwargs: Any, + ) -> None: + """Called externally to init the connection in a thread safe manner.""" + merged_kwargs = {**self.shared_base_kwargs, **kwargs} + with self._credentials_provider: + credentials, _ = self._renew_credentials() + self._confluence = self._initialize_connection_helper( + credentials, **merged_kwargs + ) + self._kwargs = merged_kwargs + + def _initialize_connection_helper( + self, + credentials: dict[str, Any], + **kwargs: Any, + ) -> Confluence: + """Called internally to init the connection. Distributed locking + to prevent multiple threads from modifying the credentials + must be handled around this function.""" + + confluence = None + + # probe connection with direct client, no retries + if "confluence_refresh_token" in credentials: + logger.info("Connecting to Confluence Cloud with OAuth Access Token.") + + oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials) + url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" + confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs) + else: + logger.info("Connecting to Confluence with Personal Access Token.") + if self._is_cloud: + confluence = Confluence( + url=self._url, + username=credentials["confluence_username"], + password=credentials["confluence_access_token"], + **kwargs, + ) + else: + confluence = Confluence( + url=self._url, + token=credentials["confluence_access_token"], + **kwargs, + ) + + return confluence + + # https://developer.atlassian.com/cloud/confluence/rate-limiting/ + # this uses the native rate limiting option provided by the + # confluence client and otherwise applies a simpler set of error handling + def _make_rate_limited_confluence_method( + self, name: str, credential_provider: CredentialsProviderInterface | None + ) -> Callable[..., Any]: + def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: + MAX_RETRIES = 5 + + TIMEOUT = 600 + timeout_at = time.monotonic() + TIMEOUT + + for attempt in range(MAX_RETRIES): + if time.monotonic() > timeout_at: + raise TimeoutError( + f"Confluence call attempts took longer than {TIMEOUT} seconds." + ) + + # we're relying more on the client to rate limit itself + # and applying our own retries in a more specific set of circumstances + try: + if credential_provider: + with credential_provider: + credentials, renewed = self._renew_credentials() + if renewed: + self._confluence = self._initialize_connection_helper( + credentials, **self._kwargs + ) + attr = getattr(self._confluence, name, None) + if attr is None: + # The underlying Confluence client doesn't have this attribute + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + return attr(*args, **kwargs) + else: + attr = getattr(self._confluence, name, None) + if attr is None: + # The underlying Confluence client doesn't have this attribute + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + return attr(*args, **kwargs) + + except HTTPError as e: + delay_until = _handle_http_error(e, attempt) + logger.warning( + f"HTTPError in confluence call. " + f"Retrying in {delay_until} seconds..." + ) + while time.monotonic() < delay_until: + # in the future, check a signal here to exit + time.sleep(1) + except AttributeError as e: + # Some error within the Confluence library, unclear why it fails. 
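The wrapped calls above follow a renew-under-lock pattern: take the provider's lock, renew the credentials if needed, and rebuild the underlying client only when the tokens actually rotated. Reduced to its skeleton (the provider and client below are simplified stand-ins, not the real interfaces; the real lock lives in Redis):

import threading
from typing import Any

class StubProvider:
    # stands in for CredentialsProviderInterface
    def __init__(self, creds: dict[str, Any]) -> None:
        self._lock = threading.Lock()
        self.creds = creds

    def __enter__(self) -> "StubProvider":
        self._lock.acquire()
        return self

    def __exit__(self, *exc: Any) -> None:
        self._lock.release()

    def renew(self) -> tuple[dict[str, Any], bool]:
        return self.creds, False  # pretend tokens never rotate in this sketch

class StubClient:
    def __init__(self, creds: dict[str, Any]) -> None:
        self._token = creds.get("token")

    def get_all_spaces(self, limit: int) -> list[dict[str, Any]]:
        return []  # placeholder for the real API call

def guarded_call(provider: StubProvider, holder: dict[str, StubClient], name: str, **kwargs: Any) -> Any:
    with provider:  # serialize renewal across workers
        creds, renewed = provider.renew()
        if renewed:
            holder["client"] = StubClient(creds)  # rebuild only on rotation
    return getattr(holder["client"], name)(**kwargs)

provider = StubProvider({"token": "xyz"})
holder = {"client": StubClient(provider.creds)}
print(guarded_call(provider, holder, "get_all_spaces", limit=1))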
+ # Users reported it to be intermittent, so just retry + if attempt == MAX_RETRIES - 1: + raise e + + logger.exception( + "Confluence Client raised an AttributeError. Retrying..." + ) + time.sleep(5) + + return wrapped_call + + # def _wrap_methods(self) -> None: + # """ + # For each attribute that is callable (i.e., a method) and doesn't start with an underscore, + # wrap it with handle_confluence_rate_limit. + # """ + # for attr_name in dir(self): + # if callable(getattr(self, attr_name)) and not attr_name.startswith("_"): + # setattr( + # self, + # attr_name, + # handle_confluence_rate_limit(getattr(self, attr_name)), + # ) + + # def _ensure_token_valid(self) -> None: + # if self._token_is_expired(): + # self._refresh_token() + # # Re-init the Confluence client with the originally stored args + # self._confluence = Confluence(self._url, *self._args, **self._kwargs) + + def __getattr__(self, name: str) -> Any: + """Dynamically intercept attribute/method access.""" + attr = getattr(self._confluence, name, None) + if attr is None: + # The underlying Confluence client doesn't have this attribute + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + # If it's not a method, just return it + if not callable(attr): + return attr + + # skip methods that start with "_" + if name.startswith("_"): + return attr + + # wrap the method with our retry handler + rate_limited_method: Callable[ + ..., Any + ] = self._make_rate_limited_confluence_method(name, self._credentials_provider) + + def wrapped_method(*args: Any, **kwargs: Any) -> Any: + return rate_limited_method(*args, **kwargs) + + return wrapped_method def _paginate_url( self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False @@ -507,63 +752,212 @@ class OnyxConfluence(Confluence): return response -def _validate_connector_configuration( - credentials: dict[str, Any], - is_cloud: bool, - wiki_base: str, -) -> None: - # test connection with direct client, no retries - confluence_client_with_minimal_retries = Confluence( - api_version="cloud" if is_cloud else "latest", - url=wiki_base.rstrip("/"), - username=credentials["confluence_username"] if is_cloud else None, - password=credentials["confluence_access_token"] if is_cloud else None, - token=credentials["confluence_access_token"] if not is_cloud else None, - backoff_and_retry=True, - max_backoff_retries=6, - max_backoff_seconds=10, +def get_user_email_from_username__server( + confluence_client: OnyxConfluence, user_name: str +) -> str | None: + global _USER_EMAIL_CACHE + if _USER_EMAIL_CACHE.get(user_name) is None: + try: + response = confluence_client.get_mobile_parameters(user_name) + email = response.get("email") + except Exception: + logger.warning(f"failed to get confluence email for {user_name}") + # For now, we'll just return None and log a warning. This means + # we will keep retrying to get the email every group sync. + email = None + # We may want to just return a string that indicates failure so we don't + # keep retrying + # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" + _USER_EMAIL_CACHE[user_name] = email + return _USER_EMAIL_CACHE[user_name] + + +def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str: + """Get Confluence Display Name based on the account-id or userkey value + + Args: + user_id (str): The user id (i.e. the account-id or userkey) + confluence_client (Confluence): The Confluence Client + + Returns: + str: The User Display Name.
'Unknown User' if the user is deactivated or not found + """ + global _USER_ID_TO_DISPLAY_NAME_CACHE + if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None: + try: + result = confluence_client.get_user_details_by_userkey(user_id) + found_display_name = result.get("displayName") + except Exception: + found_display_name = None + + if not found_display_name: + try: + result = confluence_client.get_user_details_by_accountid(user_id) + found_display_name = result.get("displayName") + except Exception: + found_display_name = None + + _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name + + return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND + + +def attachment_to_content( + confluence_client: OnyxConfluence, + attachment: dict[str, Any], + parent_content_id: str | None = None, +) -> str | None: + """If it returns None, assume that we should skip this attachment.""" + if not validate_attachment_filetype(attachment): + return None + + if "api.atlassian.com" in confluence_client.url: + # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get + if not parent_content_id: + logger.warning( + "parent_content_id is required to download attachments from Confluence Cloud!" + ) + return None + + download_link = ( + confluence_client.url + + f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download" + ) + else: + download_link = confluence_client.url + attachment["_links"]["download"] + + attachment_size = attachment["extensions"]["fileSize"] + if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: + logger.warning( + f"Skipping {download_link} due to size. " + f"size={attachment_size} " + f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}" + ) + return None + + logger.info(f"_attachment_to_content - _session.get: link={download_link}") + + # why are we using session.get here? we probably won't retry these ... is that ok? + response = confluence_client._session.get(download_link) + if response.status_code != 200: + logger.warning( + f"Failed to fetch {download_link} with invalid status code {response.status_code}" + ) + return None + + extracted_text = extract_file_text( + io.BytesIO(response.content), + file_name=attachment["title"], + break_on_unprocessable=False, ) - spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1) + if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: + logger.warning( + f"Skipping {download_link} due to char count. " + f"char count={len(extracted_text)} " + f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}" + ) + return None - # uncomment the following for testing - # the following is an attempt to retrieve the user's timezone - # Unfornately, all data is returned in UTC regardless of the user's time zone - # even tho CQL parses incoming times based on the user's time zone - # space_key = spaces["results"][0]["key"] - # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space") + return extracted_text - if not spaces: - raise RuntimeError( - f"No spaces found at {wiki_base}! " - "Check your credentials and wiki_base and make sure " - "is_cloud is set correctly." 
+ +def extract_text_from_confluence_html( + confluence_client: OnyxConfluence, + confluence_object: dict[str, Any], + fetched_titles: set[str], +) -> str: + """Parse a Confluence html page and replace the 'user Id' with the real + User Display Name + + Args: + confluence_object (dict): The confluence object as a dict + confluence_client (Confluence): Confluence client + fetched_titles (set[str]): The titles of the pages that have already been fetched + Returns: + str: loaded and formatted Confluence page + """ + body = confluence_object["body"] + object_html = body.get("storage", body.get("view", {})).get("value") + + soup = bs4.BeautifulSoup(object_html, "html.parser") + for user in soup.findAll("ri:user"): + user_id = ( + user.attrs["ri:account-id"] + if "ri:account-id" in user.attrs + else user.get("ri:userkey") + ) + if not user_id: + logger.warning( + "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}" + ) + continue + # Include @ sign for tagging, more clear for LLM + user.replaceWith("@" + _get_user(confluence_client, user_id)) + + for html_page_reference in soup.findAll("ac:structured-macro"): + # Here, we only want to process page-within-page macros + if html_page_reference.attrs.get("ac:name") != "include": + continue + + page_data = html_page_reference.find("ri:page") + if not page_data: + logger.warning( + f"Skipping retrieval of {html_page_reference} because page data is missing" + ) + continue + + page_title = page_data.attrs.get("ri:content-title") + if not page_title: + # only fetch pages that have a title + logger.warning( + f"Skipping retrieval of {html_page_reference} because it has no title" + ) + continue + + if page_title in fetched_titles: + # prevent recursive fetching of pages + logger.debug(f"Skipping {page_title} because it has already been fetched") + continue + + fetched_titles.add(page_title) + + # Wrap this in a try-except because there are some pages that might not exist + try: + page_query = f"type=page and title='{quote(page_title)}'" + + page_contents: dict[str, Any] | None = None + # Confluence enforces title uniqueness, so we should only get one result here + for page in confluence_client.paginated_cql_retrieval( + cql=page_query, + expand="body.storage.value", + limit=1, + ): + page_contents = page + break + except Exception as e: + logger.warning( + f"Error getting page contents for object {confluence_object}: {e}" + ) + continue + + if not page_contents: + continue + + text_from_page = extract_text_from_confluence_html( + confluence_client=confluence_client, + confluence_object=page_contents, + fetched_titles=fetched_titles, ) + html_page_reference.replaceWith(text_from_page) -def build_confluence_client( - credentials: dict[str, Any], - is_cloud: bool, - wiki_base: str, -) -> OnyxConfluence: - try: - _validate_connector_configuration( - credentials=credentials, - is_cloud=is_cloud, - wiki_base=wiki_base, - ) - except Exception as e: - raise ConnectorValidationError(str(e)) + for html_link_body in soup.findAll("ac:link-body"): + # This extracts the text from inline links in the page so they can be + # represented in the document text as plain text + try: + text_from_link = html_link_body.text + html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})") + except Exception as e: + logger.warning(f"Error processing ac:link-body: {e}") - return OnyxConfluence( - api_version="cloud" if is_cloud else "latest", - # Remove trailing slash from wiki_base if present - url=wiki_base.rstrip("/"), - # passing in username causes issues
for Confluence data center - username=credentials["confluence_username"] if is_cloud else None, - password=credentials["confluence_access_token"] if is_cloud else None, - token=credentials["confluence_access_token"] if not is_cloud else None, - backoff_and_retry=True, - max_backoff_retries=10, - max_backoff_seconds=60, - cloud=is_cloud, - ) + return format_document_soup(soup) diff --git a/backend/onyx/connectors/confluence/utils.py b/backend/onyx/connectors/confluence/utils.py index b77696645..801e24d4a 100644 --- a/backend/onyx/connectors/confluence/utils.py +++ b/backend/onyx/connectors/confluence/utils.py @@ -1,185 +1,38 @@ -import io +import math +import time +from collections.abc import Callable from datetime import datetime +from datetime import timedelta from datetime import timezone from typing import Any +from typing import cast from typing import TYPE_CHECKING +from typing import TypeVar from urllib.parse import parse_qs from urllib.parse import quote from urllib.parse import urlparse import bs4 +import requests +from pydantic import BaseModel -from onyx.configs.app_configs import ( - CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD, -) -from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD -from onyx.file_processing.extract_file_text import extract_file_text -from onyx.file_processing.html_utils import format_document_soup from onyx.utils.logger import setup_logger if TYPE_CHECKING: - from onyx.connectors.confluence.onyx_confluence import OnyxConfluence + pass logger = setup_logger() - -_USER_EMAIL_CACHE: dict[str, str | None] = {} +CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" +RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() -def get_user_email_from_username__server( - confluence_client: "OnyxConfluence", user_name: str -) -> str | None: - global _USER_EMAIL_CACHE - if _USER_EMAIL_CACHE.get(user_name) is None: - try: - response = confluence_client.get_mobile_parameters(user_name) - email = response.get("email") - except Exception: - logger.warning(f"failed to get confluence email for {user_name}") - # For now, we'll just return None and log a warning. This means - # we will keep retrying to get the email every group sync. - email = None - # We may want to just return a string that indicates failure so we dont - # keep retrying - # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}" - _USER_EMAIL_CACHE[user_name] = email - return _USER_EMAIL_CACHE[user_name] - - -_USER_NOT_FOUND = "Unknown Confluence User" -_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} - - -def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str: - """Get Confluence Display Name based on the account-id or userkey value - - Args: - user_id (str): The user id (i.e: the account-id or userkey) - confluence_client (Confluence): The Confluence Client - - Returns: - str: The User Display Name. 
'Unknown User' if the user is deactivated or not found - """ - global _USER_ID_TO_DISPLAY_NAME_CACHE - if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None: - try: - result = confluence_client.get_user_details_by_userkey(user_id) - found_display_name = result.get("displayName") - except Exception: - found_display_name = None - - if not found_display_name: - try: - result = confluence_client.get_user_details_by_accountid(user_id) - found_display_name = result.get("displayName") - except Exception: - found_display_name = None - - _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name - - return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND - - -def extract_text_from_confluence_html( - confluence_client: "OnyxConfluence", - confluence_object: dict[str, Any], - fetched_titles: set[str], -) -> str: - """Parse a Confluence html page and replace the 'user Id' by the real - User Display Name - - Args: - confluence_object (dict): The confluence object as a dict - confluence_client (Confluence): Confluence client - fetched_titles (set[str]): The titles of the pages that have already been fetched - Returns: - str: loaded and formated Confluence page - """ - body = confluence_object["body"] - object_html = body.get("storage", body.get("view", {})).get("value") - - soup = bs4.BeautifulSoup(object_html, "html.parser") - for user in soup.findAll("ri:user"): - user_id = ( - user.attrs["ri:account-id"] - if "ri:account-id" in user.attrs - else user.get("ri:userkey") - ) - if not user_id: - logger.warning( - "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}" - ) - continue - # Include @ sign for tagging, more clear for LLM - user.replaceWith("@" + _get_user(confluence_client, user_id)) - - for html_page_reference in soup.findAll("ac:structured-macro"): - # Here, we only want to process page within page macros - if html_page_reference.attrs.get("ac:name") != "include": - continue - - page_data = html_page_reference.find("ri:page") - if not page_data: - logger.warning( - f"Skipping retrieval of {html_page_reference} because because page data is missing" - ) - continue - - page_title = page_data.attrs.get("ri:content-title") - if not page_title: - # only fetch pages that have a title - logger.warning( - f"Skipping retrieval of {html_page_reference} because it has no title" - ) - continue - - if page_title in fetched_titles: - # prevent recursive fetching of pages - logger.debug(f"Skipping {page_title} because it has already been fetched") - continue - - fetched_titles.add(page_title) - - # Wrap this in a try-except because there are some pages that might not exist - try: - page_query = f"type=page and title='{quote(page_title)}'" - - page_contents: dict[str, Any] | None = None - # Confluence enforces title uniqueness, so we should only get one result here - for page in confluence_client.paginated_cql_retrieval( - cql=page_query, - expand="body.storage.value", - limit=1, - ): - page_contents = page - break - except Exception as e: - logger.warning( - f"Error getting page contents for object {confluence_object}: {e}" - ) - continue - - if not page_contents: - continue - - text_from_page = extract_text_from_confluence_html( - confluence_client=confluence_client, - confluence_object=page_contents, - fetched_titles=fetched_titles, - ) - - html_page_reference.replaceWith(text_from_page) - - for html_link_body in soup.findAll("ac:link-body"): - # This extracts the text from inline links in the page so they can be - # represented in the document text as plain text - try: 
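The display-name helper being relocated here follows a common shape: try a primary endpoint, fall back to a secondary one, and cache the result per user id. A standalone sketch with hypothetical placeholder fetchers (note that a membership check, unlike the `.get(user_id) is None` test used in the connector, also memoizes misses, so deactivated users are not re-queried for every document):

_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_NOT_FOUND = "Unknown Confluence User"

def fetch_by_userkey(user_id: str) -> str | None:
    return None  # placeholder: pretend the userkey endpoint had no match

def fetch_by_accountid(user_id: str) -> str | None:
    return None  # placeholder: pretend the account-id endpoint had no match

def display_name(user_id: str) -> str:
    if user_id not in _DISPLAY_NAME_CACHE:
        name = None
        try:
            name = fetch_by_userkey(user_id)
        except Exception:
            pass
        if not name:
            try:
                name = fetch_by_accountid(user_id)
            except Exception:
                pass
        _DISPLAY_NAME_CACHE[user_id] = name  # cache hits and misses alike
    return _DISPLAY_NAME_CACHE.get(user_id) or _NOT_FOUND

print(display_name("abc123"))  # -> Unknown Confluence User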
- text_from_link = html_link_body.text - html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})") - except Exception as e: - logger.warning(f"Error processing ac:link-body: {e}") - - return format_document_soup(soup) +class TokenResponse(BaseModel): + access_token: str + expires_in: int + token_type: str + refresh_token: str + scope: str def validate_attachment_filetype(attachment: dict[str, Any]) -> bool: @@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool: ] -def attachment_to_content( - confluence_client: "OnyxConfluence", - attachment: dict[str, Any], -) -> str | None: - """If it returns None, assume that we should skip this attachment.""" - if not validate_attachment_filetype(attachment): - return None - - download_link = confluence_client.url + attachment["_links"]["download"] - - attachment_size = attachment["extensions"]["fileSize"] - if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: - logger.warning( - f"Skipping {download_link} due to size. " - f"size={attachment_size} " - f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}" - ) - return None - - logger.info(f"_attachment_to_content - _session.get: link={download_link}") - response = confluence_client._session.get(download_link) - if response.status_code != 200: - logger.warning( - f"Failed to fetch {download_link} with invalid status code {response.status_code}" - ) - return None - - extracted_text = extract_file_text( - io.BytesIO(response.content), - file_name=attachment["title"], - break_on_unprocessable=False, - ) - if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: - logger.warning( - f"Skipping {download_link} due to char count. " - f"char count={len(extracted_text)} " - f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}" - ) - return None - - return extracted_text - - def build_confluence_document_id( base_url: str, content_url: str, is_cloud: bool ) -> str: @@ -284,6 +94,137 @@ def datetime_from_string(datetime_string: str) -> datetime: return datetime_object +def confluence_refresh_tokens( + client_id: str, client_secret: str, cloud_id: str, refresh_token: str +) -> dict[str, Any]: + # rotate the refresh and access token + # Note that access tokens are only good for an hour in confluence cloud, + # so we're going to have problems if the connector runs for longer + # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair + response = requests.post( + CONFLUENCE_OAUTH_TOKEN_URL, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "refresh_token", + "client_id": client_id, + "client_secret": client_secret, + "refresh_token": refresh_token, + }, + ) + + try: + token_response = TokenResponse.model_validate_json(response.text) + except Exception: + raise RuntimeError("Confluence Cloud token refresh failed.") + + now = datetime.now(timezone.utc) + expires_at = now + timedelta(seconds=token_response.expires_in) + + new_credentials: dict[str, Any] = {} + new_credentials["confluence_access_token"] = token_response.access_token + new_credentials["confluence_refresh_token"] = token_response.refresh_token + new_credentials["created_at"] = now.isoformat() + new_credentials["expires_at"] = expires_at.isoformat() + new_credentials["expires_in"] = token_response.expires_in + new_credentials["scope"] = token_response.scope + new_credentials["cloud_id"] = cloud_id + return new_credentials + + +F = 
TypeVar("F", bound=Callable[..., Any]) + + +# https://developer.atlassian.com/cloud/confluence/rate-limiting/ +# this uses the native rate limiting option provided by the +# confluence client and otherwise applies a simpler set of error handling +def handle_confluence_rate_limit(confluence_call: F) -> F: + def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: + MAX_RETRIES = 5 + + TIMEOUT = 600 + timeout_at = time.monotonic() + TIMEOUT + + for attempt in range(MAX_RETRIES): + if time.monotonic() > timeout_at: + raise TimeoutError( + f"Confluence call attempts took longer than {TIMEOUT} seconds." + ) + + try: + # we're relying more on the client to rate limit itself + # and applying our own retries in a more specific set of circumstances + return confluence_call(*args, **kwargs) + except requests.HTTPError as e: + delay_until = _handle_http_error(e, attempt) + logger.warning( + f"HTTPError in confluence call. " + f"Retrying in {delay_until} seconds..." + ) + while time.monotonic() < delay_until: + # in the future, check a signal here to exit + time.sleep(1) + except AttributeError as e: + # Some error within the Confluence library, unclear why it fails. + # Users reported it to be intermittent, so just retry + if attempt == MAX_RETRIES - 1: + raise e + + logger.exception( + "Confluence Client raised an AttributeError. Retrying..." + ) + time.sleep(5) + + return cast(F, wrapped_call) + + +def _handle_http_error(e: requests.HTTPError, attempt: int) -> int: + MIN_DELAY = 2 + MAX_DELAY = 60 + STARTING_DELAY = 5 + BACKOFF = 2 + + # Check if the response or headers are None to avoid potential AttributeError + if e.response is None or e.response.headers is None: + logger.warning("HTTPError with `None` as response or as headers") + raise e + + if ( + e.response.status_code != 429 + and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower() + ): + raise e + + retry_after = None + + retry_after_header = e.response.headers.get("Retry-After") + if retry_after_header is not None: + try: + retry_after = int(retry_after_header) + if retry_after > MAX_DELAY: + logger.warning( + f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..." + ) + retry_after = MAX_DELAY + if retry_after < MIN_DELAY: + retry_after = MIN_DELAY + except ValueError: + pass + + if retry_after is not None: + logger.warning( + f"Rate limiting with retry header. Retrying after {retry_after} seconds..." + ) + delay = retry_after + else: + logger.warning( + "Rate limiting without retry header. Retrying with exponential backoff..." 
+ ) + delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY) + + delay_until = math.ceil(time.monotonic() + delay) + return delay_until + + def get_single_param_from_url(url: str, param: str) -> str | None: """Get a parameter from a url""" parsed_url = urlparse(url) diff --git a/backend/onyx/connectors/credentials_provider.py b/backend/onyx/connectors/credentials_provider.py new file mode 100644 index 000000000..2ba7e6a6f --- /dev/null +++ b/backend/onyx/connectors/credentials_provider.py @@ -0,0 +1,135 @@ +import uuid +from types import TracebackType +from typing import Any + +from redis.lock import Lock as RedisLock +from sqlalchemy import select + +from onyx.connectors.interfaces import CredentialsProviderInterface +from onyx.db.engine import get_session_with_tenant +from onyx.db.models import Credential +from onyx.redis.redis_pool import get_redis_client + + +class OnyxDBCredentialsProvider( + CredentialsProviderInterface["OnyxDBCredentialsProvider"] +): + """Implementation to allow the connector to call back and update credentials in the db. + Required in cases where credentials can rotate while the connector is running. + """ + + LOCK_TTL = 900 # TTL of the lock + + def __init__(self, tenant_id: str, connector_name: str, credential_id: int): + self._tenant_id = tenant_id + self._connector_name = connector_name + self._credential_id = credential_id + + self.redis_client = get_redis_client(tenant_id=tenant_id) + + # lock used to prevent overlapping renewal of credentials + self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}" + self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL) + + def __enter__(self) -> "OnyxDBCredentialsProvider": + acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL) + if not acquired: + raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}") + + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + """Release the lock when exiting the context.""" + if self._lock and self._lock.owned(): + self._lock.release() + + def get_tenant_id(self) -> str | None: + return self._tenant_id + + def get_provider_key(self) -> str: + return str(self._credential_id) + + def get_credentials(self) -> dict[str, Any]: + with get_session_with_tenant(tenant_id=self._tenant_id) as db_session: + credential = db_session.execute( + select(Credential).where(Credential.id == self._credential_id) + ).scalar_one() + + if credential is None: + raise ValueError( + f"No credential found: credential={self._credential_id}" + ) + + return credential.credential_json + + def set_credentials(self, credential_json: dict[str, Any]) -> None: + with get_session_with_tenant(tenant_id=self._tenant_id) as db_session: + try: + credential = db_session.execute( + select(Credential) + .where(Credential.id == self._credential_id) + .with_for_update() + ).scalar_one() + + if credential is None: + raise ValueError( + f"No credential found: credential={self._credential_id}" + ) + + credential.credential_json = credential_json + db_session.commit() + except Exception: + db_session.rollback() + raise + + def is_dynamic(self) -> bool: + return True + + +class OnyxStaticCredentialsProvider( + CredentialsProviderInterface["OnyxStaticCredentialsProvider"] +): + """Implementation (a very simple one!)
to handle static credentials.""" + + def __init__( + self, + tenant_id: str | None, + connector_name: str, + credential_json: dict[str, Any], + ): + self._tenant_id = tenant_id + self._connector_name = connector_name + self._credential_json = credential_json + + self._provider_key = str(uuid.uuid4()) + + def __enter__(self) -> "OnyxStaticCredentialsProvider": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + pass + + def get_tenant_id(self) -> str | None: + return self._tenant_id + + def get_provider_key(self) -> str: + return self._provider_key + + def get_credentials(self) -> dict[str, Any]: + return self._credential_json + + def set_credentials(self, credential_json: dict[str, Any]) -> None: + self._credential_json = credential_json + + def is_dynamic(self) -> bool: + return False diff --git a/backend/onyx/connectors/factory.py b/backend/onyx/connectors/factory.py index 14221d2e3..73593cc60 100644 --- a/backend/onyx/connectors/factory.py +++ b/backend/onyx/connectors/factory.py @@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector from onyx.connectors.bookstack.connector import BookstackConnector from onyx.connectors.clickup.connector import ClickupConnector from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.discord.connector import DiscordConnector from onyx.connectors.discourse.connector import DiscourseConnector from onyx.connectors.document360.connector import Document360Connector @@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector from onyx.connectors.hubspot.connector import HubSpotConnector from onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointConnector +from onyx.connectors.interfaces import CredentialsConnector from onyx.connectors.interfaces import EventConnector from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector @@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id from onyx.db.credentials import backend_update_credential_json from onyx.db.credentials import fetch_credential_by_id from onyx.db.models import Credential +from shared_configs.contextvars import get_current_tenant_id class ConnectorMissingException(Exception): @@ -167,10 +170,17 @@ def instantiate_connector( connector_class = identify_connector_class(source, input_type) connector = connector_class(**connector_specific_config) - new_credentials = connector.load_credentials(credential.credential_json) - if new_credentials is not None: - backend_update_credential_json(credential, new_credentials, db_session) + if isinstance(connector, CredentialsConnector): + provider = OnyxDBCredentialsProvider( + get_current_tenant_id(), str(source), credential.id + ) + connector.set_credentials_provider(provider) + else: + new_credentials = connector.load_credentials(credential.credential_json) + + if new_credentials is not None: + backend_update_credential_json(credential, new_credentials, db_session) return connector diff --git a/backend/onyx/connectors/interfaces.py b/backend/onyx/connectors/interfaces.py index 8516d08a3..0b2f8b661 100644 --- a/backend/onyx/connectors/interfaces.py +++ b/backend/onyx/connectors/interfaces.py @@ -1,7 +1,10 @@ import abc from collections.abc import Generator from collections.abc import Iterator +from 
types import TracebackType from typing import Any +from typing import Generic +from typing import TypeVar from pydantic import BaseModel @@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector): raise NotImplementedError +T = TypeVar("T", bound="CredentialsProviderInterface") + + +class CredentialsProviderInterface(abc.ABC, Generic[T]): + @abc.abstractmethod + def __enter__(self) -> T: + raise NotImplementedError + + @abc.abstractmethod + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + raise NotImplementedError + + @abc.abstractmethod + def get_tenant_id(self) -> str | None: + raise NotImplementedError + + @abc.abstractmethod + def get_provider_key(self) -> str: + """a unique key that the connector can use to lock around a credential + that might be used simultaneously. + + Will typically be the credential id, but can also just be something random + in cases when there is nothing to lock (aka static credentials) + """ + raise NotImplementedError + + @abc.abstractmethod + def get_credentials(self) -> dict[str, Any]: + raise NotImplementedError + + @abc.abstractmethod + def set_credentials(self, credential_json: dict[str, Any]) -> None: + raise NotImplementedError + + @abc.abstractmethod + def is_dynamic(self) -> bool: + """If dynamic, the credentials may change during usage ... meaning the client + needs to use the locking features of the credentials provider to operate + correctly. + + If static, the client can simply reference the credentials once and use them + through the entire indexing run. + """ + raise NotImplementedError + + +class CredentialsConnector(BaseConnector): + """Implement this if the connector needs to be able to read and write credentials + on the fly. Typically used with shared credentials/tokens that might be renewed + at any time.""" + + @abc.abstractmethod + def set_credentials_provider( + self, credentials_provider: CredentialsProviderInterface + ) -> None: + raise NotImplementedError + + # Event driven class EventConnector(BaseConnector): @abc.abstractmethod diff --git a/backend/onyx/connectors/slack/utils.py index 8428a4534..757036a9f 100644 --- a/backend/onyx/connectors/slack/utils.py +++ b/backend/onyx/connectors/slack/utils.py @@ -72,6 +72,7 @@ def make_slack_api_rate_limited( @wraps(call) def rate_limited_call(**kwargs: Any) -> SlackResponse: last_exception = None + for _ in range(max_retries): try: # Make the API call diff --git a/backend/onyx/connectors/web/connector.py index 9bd6e2073..9380791db 100644 --- a/backend/onyx/connectors/web/connector.py +++ b/backend/onyx/connectors/web/connector.py @@ -42,6 +42,10 @@ from shared_configs.configs import MULTI_TENANT logger = setup_logger() WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20 +# Threshold for determining when to replace vs append iframe content +IFRAME_TEXT_LENGTH_THRESHOLD = 700 +# Message indicating JavaScript is disabled, which often appears when scraping fails +JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser" class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): @@ -138,7 +142,8 @@ def get_internal_links( # Account for malformed backslashes in URLs href = href.replace("\\", "/") - if should_ignore_pound and "#" in href: + # "#!" indicates the page is using a hashbang URL, which is a client-side routing technique + if should_ignore_pound and "#" in href and "#!"
not in href: href = href.split("#")[0] if not is_valid_url(href): @@ -288,6 +293,7 @@ class WebConnector(LoadConnector): and converts them into documents""" visited_links: set[str] = set() to_visit: list[str] = self.to_visit_list + content_hashes: set[int] = set() if not to_visit: raise ValueError("No URLs to visit") @@ -302,29 +308,30 @@ class WebConnector(LoadConnector): playwright, context = start_playwright() restart_playwright = False while to_visit: - current_url = to_visit.pop() - if current_url in visited_links: + initial_url = to_visit.pop() + if initial_url in visited_links: continue - visited_links.add(current_url) + visited_links.add(initial_url) try: - protected_url_check(current_url) + protected_url_check(initial_url) except Exception as e: - last_error = f"Invalid URL {current_url} due to {e}" + last_error = f"Invalid URL {initial_url} due to {e}" logger.warning(last_error) continue - logger.info(f"Visiting {current_url}") + index = len(visited_links) + logger.info(f"{index}: Visiting {initial_url}") try: - check_internet_connection(current_url) + check_internet_connection(initial_url) if restart_playwright: playwright, context = start_playwright() restart_playwright = False - if current_url.split(".")[-1] == "pdf": + if initial_url.split(".")[-1] == "pdf": # PDF files are not checked for links - response = requests.get(current_url) + response = requests.get(initial_url) page_text, metadata = read_pdf_file( file=io.BytesIO(response.content) ) @@ -332,10 +339,10 @@ class WebConnector(LoadConnector): doc_batch.append( Document( - id=current_url, - sections=[Section(link=current_url, text=page_text)], + id=initial_url, + sections=[Section(link=initial_url, text=page_text)], source=DocumentSource.WEB, - semantic_identifier=current_url.split("/")[-1], + semantic_identifier=initial_url.split("/")[-1], metadata=metadata, doc_updated_at=_get_datetime_from_last_modified_header( last_modified @@ -347,21 +354,29 @@ class WebConnector(LoadConnector): continue page = context.new_page() - page_response = page.goto(current_url) + + # Can't use wait_until="networkidle" because it interferes with the scrolling behavior + page_response = page.goto( + initial_url, + timeout=30000, # 30 seconds + ) + last_modified = ( page_response.header_value("Last-Modified") if page_response else None ) - final_page = page.url - if final_page != current_url: - logger.info(f"Redirected to {final_page}") - protected_url_check(final_page) - current_url = final_page - if current_url in visited_links: - logger.info("Redirected page already indexed") + final_url = page.url + if final_url != initial_url: + protected_url_check(final_url) + # log the redirect before reassigning, so both URLs are visible + logger.info(f"{index}: {initial_url} redirected to {final_url}") + initial_url = final_url + if initial_url in visited_links: + logger.info( + f"{index}: redirect target {final_url} already indexed" + ) continue - visited_links.add(current_url) + visited_links.add(initial_url) if self.scroll_before_scraping: scroll_attempts = 0 @@ -379,26 +394,58 @@ class WebConnector(LoadConnector): soup = BeautifulSoup(content, "html.parser") if self.recursive: - internal_links = get_internal_links(base_url, current_url, soup) + internal_links = get_internal_links(base_url, initial_url, soup) for link in internal_links: if link not in visited_links: to_visit.append(link) if page_response and str(page_response.status)[0] in ("4", "5"): - last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response" + last_error = f"Skipped indexing {initial_url} due to HTTP 
{page_response.status} response" logger.info(last_error) continue parsed_html = web_html_cleanup(soup, self.mintlify_cleanup) + """For websites containing iframes that need to be scraped, + the code below can extract text from within these iframes. + """ + logger.debug( + f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}" + ) + if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text: + iframe_count = page.frame_locator("iframe").locator("html").count() + if iframe_count > 0: + iframe_texts = ( + page.frame_locator("iframe") + .locator("html") + .all_inner_texts() + ) + document_text = "\n".join(iframe_texts) + """ 700 is the threshold value for the length of the text extracted + from the iframe based on the issue faced """ + if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD: + parsed_html.cleaned_text = document_text + else: + parsed_html.cleaned_text += "\n" + document_text + + # Sometimes pages with #! will serve duplicate content + # There are also just other ways this can happen + hashed_text = hash((parsed_html.title, parsed_html.cleaned_text)) + if hashed_text in content_hashes: + logger.info( + f"{index}: Skipping duplicate title + content for {initial_url}" + ) + continue + content_hashes.add(hashed_text) + doc_batch.append( Document( - id=current_url, + id=initial_url, sections=[ - Section(link=current_url, text=parsed_html.cleaned_text) + Section(link=initial_url, text=parsed_html.cleaned_text) ], source=DocumentSource.WEB, - semantic_identifier=parsed_html.title or current_url, + semantic_identifier=parsed_html.title or initial_url, metadata={}, doc_updated_at=_get_datetime_from_last_modified_header( last_modified @@ -410,7 +457,7 @@ class WebConnector(LoadConnector): page.close() except Exception as e: - last_error = f"Failed to fetch '{current_url}': {e}" + last_error = f"Failed to fetch '{initial_url}': {e}" logger.exception(last_error) playwright.stop() restart_playwright = True diff --git a/backend/onyx/context/search/models.py b/backend/onyx/context/search/models.py index 7eeb35686..3d19db186 100644 --- a/backend/onyx/context/search/models.py +++ b/backend/onyx/context/search/models.py @@ -76,6 +76,10 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting): provider_type=search_settings.provider_type, index_name=search_settings.index_name, multipass_indexing=search_settings.multipass_indexing, + embedding_precision=search_settings.embedding_precision, + reduced_dimension=search_settings.reduced_dimension, + # Whether switching to this model requires re-indexing + background_reindex_enabled=search_settings.background_reindex_enabled, # Reranking Details rerank_model_name=search_settings.rerank_model_name, rerank_provider_type=search_settings.rerank_provider_type, diff --git a/backend/onyx/db/chat.py b/backend/onyx/db/chat.py index 0da54b8ad..5335d25b8 100644 --- a/backend/onyx/db/chat.py +++ b/backend/onyx/db/chat.py @@ -168,7 +168,7 @@ def get_chat_sessions_by_user( if not include_onyxbot_flows: stmt = stmt.where(ChatSession.onyxbot_flow.is_(False)) - stmt = stmt.order_by(desc(ChatSession.time_created)) + stmt = stmt.order_by(desc(ChatSession.time_updated)) if deleted is not None: stmt = stmt.where(ChatSession.deleted == deleted) @@ -962,6 +962,7 @@ def translate_db_message_to_chat_message_detail( chat_message.sub_questions ), refined_answer_improvement=chat_message.refined_answer_improvement, + is_agentic=chat_message.is_agentic, error=chat_message.error, ) diff --git a/backend/onyx/db/credentials.py 
b/backend/onyx/db/credentials.py index 40edbcef3..cbd578ff2 100644 --- a/backend/onyx/db/credentials.py +++ b/backend/onyx/db/credentials.py @@ -360,18 +360,13 @@ def backend_update_credential_json( db_session.commit() -def delete_credential( +def _delete_credential_internal( + credential: Credential, credential_id: int, - user: User | None, db_session: Session, force: bool = False, ) -> None: - credential = fetch_credential_by_id_for_user(credential_id, user, db_session) - if credential is None: - raise ValueError( - f"Credential by provided id {credential_id} does not exist or does not belong to user" - ) - + """Internal utility function to handle the actual deletion of a credential""" associated_connectors = ( db_session.query(ConnectorCredentialPair) .filter(ConnectorCredentialPair.credential_id == credential_id) @@ -416,6 +411,35 @@ def delete_credential( db_session.commit() +def delete_credential_for_user( + credential_id: int, + user: User, + db_session: Session, + force: bool = False, +) -> None: + """Delete a credential that belongs to a specific user""" + credential = fetch_credential_by_id_for_user(credential_id, user, db_session) + if credential is None: + raise ValueError( + f"Credential by provided id {credential_id} does not exist or does not belong to user" + ) + + _delete_credential_internal(credential, credential_id, db_session, force) + + +def delete_credential( + credential_id: int, + db_session: Session, + force: bool = False, +) -> None: + """Delete a credential regardless of ownership (admin function)""" + credential = fetch_credential_by_id(credential_id, db_session) + if credential is None: + raise ValueError(f"Credential by provided id {credential_id} does not exist") + + _delete_credential_internal(credential, credential_id, db_session, force) + + def create_initial_public_credential(db_session: Session) -> None: error_msg = ( "DB is not in a valid initial state." 
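Side note for reviewers: the credentials refactor above keeps one shared deletion path (_delete_credential_internal) behind two entry points, one ownership-checked and one admin-only. A minimal usage sketch follows; the session helper, the `user` object, and the credential IDs are hypothetical and only illustrate the two call shapes:

from onyx.db.credentials import delete_credential, delete_credential_for_user
from onyx.db.engine import get_session_context_manager  # assumed session helper

with get_session_context_manager() as db_session:
    # User-scoped flow: raises ValueError if credential 42 does not exist
    # or does not belong to `user`
    delete_credential_for_user(credential_id=42, user=user, db_session=db_session)

    # Admin flow: deletes by id regardless of ownership
    delete_credential(credential_id=43, db_session=db_session)
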
diff --git a/backend/onyx/db/enums.py b/backend/onyx/db/enums.py index 329ee35fc..c5a3ced2f 100644 --- a/backend/onyx/db/enums.py +++ b/backend/onyx/db/enums.py @@ -63,6 +63,9 @@ class IndexModelStatus(str, PyEnum): PRESENT = "PRESENT" FUTURE = "FUTURE" + def is_current(self) -> bool: + return self == IndexModelStatus.PRESENT + class ChatSessionSharedStatus(str, PyEnum): PUBLIC = "public" @@ -83,3 +86,11 @@ class AccessType(str, PyEnum): PUBLIC = "public" PRIVATE = "private" SYNC = "sync" + + +class EmbeddingPrecision(str, PyEnum): + # matches the Vespa tensor cell type + # only float / bfloat16 are supported for now, since there's no + # good reason to specify anything else + BFLOAT16 = "bfloat16" + FLOAT = "float" diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 0001ec318..484d24620 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -7,6 +7,7 @@ from typing import Optional from uuid import uuid4 from pydantic import BaseModel +from sqlalchemy.orm import validates from typing_extensions import TypedDict # noreorder from uuid import UUID @@ -45,7 +46,13 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType from onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.configs.constants import MessageType -from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus +from onyx.db.enums import ( + AccessType, + EmbeddingPrecision, + IndexingMode, + SyncType, + SyncStatus, +) from onyx.configs.constants import NotificationType from onyx.configs.constants import SearchFeedbackType from onyx.configs.constants import TokenRateLimitScope @@ -206,6 +213,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base): primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)", ) + @validates("email") + def validate_email(self, key: str, value: str) -> str: + return value.lower() if value else value + @property def password_configured(self) -> bool: """ @@ -711,6 +722,23 @@ class SearchSettings(Base): ForeignKey("embedding_provider.provider_type"), nullable=True ) + # Whether switching to this model should re-index all connectors in the background. + # If no re-index is needed, this is ignored. Only used during the switch-over process. + background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True) + + # allows for quantization -> less memory usage for a small performance hit + embedding_precision: Mapped[EmbeddingPrecision] = mapped_column( + Enum(EmbeddingPrecision, native_enum=False) + ) + + # can be used to reduce dimensionality of vectors and save memory with + # a small performance hit. More details in the `Reducing embedding dimensions` + # section here: + # https://platform.openai.com/docs/guides/embeddings#embedding-models + # If not specified, will just use the model_dim without any reduction. 
+ # NOTE: this is only currently available for OpenAI models + reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True) + # Mini and Large Chunks (large chunk also checks for model max context) multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True) @@ -792,6 +820,12 @@ class SearchSettings(Base): self.multipass_indexing, self.model_name, self.provider_type ) + @property + def final_embedding_dim(self) -> int: + if self.reduced_dimension: + return self.reduced_dimension + return self.model_dim + @staticmethod def can_use_large_chunks( multipass: bool, model_name: str, provider_type: EmbeddingProvider | None @@ -1756,6 +1790,7 @@ class ChannelConfig(TypedDict): channel_name: str | None # None for default channel config respond_tag_only: NotRequired[bool] # defaults to False respond_to_bots: NotRequired[bool] # defaults to False + is_ephemeral: NotRequired[bool] # defaults to False respond_member_group_list: NotRequired[list[str]] answer_filters: NotRequired[list[AllowedAnswerFilters]] # If None then no follow up @@ -2270,6 +2305,10 @@ class UserTenantMapping(Base): email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True) tenant_id: Mapped[str] = mapped_column(String, nullable=False) + @validates("email") + def validate_email(self, key: str, value: str) -> str: + return value.lower() if value else value + # This is a mapping from tenant IDs to anonymous user paths class TenantAnonymousUserPath(Base): diff --git a/backend/onyx/db/persona.py b/backend/onyx/db/persona.py index b879e426c..ae37b3f50 100644 --- a/backend/onyx/db/persona.py +++ b/backend/onyx/db/persona.py @@ -209,13 +209,21 @@ def create_update_persona( if not all_prompt_ids: raise ValueError("No prompt IDs provided") + is_default_persona: bool | None = create_persona_request.is_default_persona # Default persona validation if create_persona_request.is_default_persona: if not create_persona_request.is_public: raise ValueError("Cannot make a default persona non public") - if user and user.role != UserRole.ADMIN: - raise ValueError("Only admins can make a default persona") + if user: + # Curators can edit default personas, but not make them + if ( + user.role == UserRole.CURATOR + or user.role == UserRole.GLOBAL_CURATOR + ): + is_default_persona = None + elif user.role != UserRole.ADMIN: + raise ValueError("Only admins can make a default persona") persona = upsert_persona( persona_id=persona_id, @@ -241,7 +249,7 @@ def create_update_persona( num_chunks=create_persona_request.num_chunks, llm_relevance_filter=create_persona_request.llm_relevance_filter, llm_filter_extraction=create_persona_request.llm_filter_extraction, - is_default_persona=create_persona_request.is_default_persona, + is_default_persona=is_default_persona, ) versioned_make_persona_private = fetch_versioned_implementation( @@ -428,7 +436,7 @@ def upsert_persona( remove_image: bool | None = None, search_start_date: datetime | None = None, builtin_persona: bool = False, - is_default_persona: bool = False, + is_default_persona: bool | None = None, label_ids: list[int] | None = None, chunks_above: int = CONTEXT_CHUNKS_ABOVE, chunks_below: int = CONTEXT_CHUNKS_BELOW, @@ -523,7 +531,11 @@ def upsert_persona( existing_persona.is_visible = is_visible existing_persona.search_start_date = search_start_date existing_persona.labels = labels or [] - existing_persona.is_default_persona = is_default_persona + existing_persona.is_default_persona = ( + is_default_persona + if is_default_persona is not None + else 
existing_persona.is_default_persona + ) # Do not delete any associations manually added unless # a new updated list is provided if document_sets is not None: @@ -575,7 +587,9 @@ def upsert_persona( display_priority=display_priority, is_visible=is_visible, search_start_date=search_start_date, - is_default_persona=is_default_persona, + is_default_persona=is_default_persona + if is_default_persona is not None + else False, labels=labels or [], ) db_session.add(new_persona) diff --git a/backend/onyx/db/search_settings.py b/backend/onyx/db/search_settings.py index 8ac120b05..371e6f1dd 100644 --- a/backend/onyx/db/search_settings.py +++ b/backend/onyx/db/search_settings.py @@ -13,6 +13,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS from onyx.context.search.models import SavedSearchSettings +from onyx.db.enums import EmbeddingPrecision from onyx.db.llm import fetch_embedding_provider from onyx.db.models import CloudEmbeddingProvider from onyx.db.models import IndexAttempt @@ -59,12 +60,15 @@ def create_search_settings( index_name=search_settings.index_name, provider_type=search_settings.provider_type, multipass_indexing=search_settings.multipass_indexing, + embedding_precision=search_settings.embedding_precision, + reduced_dimension=search_settings.reduced_dimension, multilingual_expansion=search_settings.multilingual_expansion, disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming, rerank_model_name=search_settings.rerank_model_name, rerank_provider_type=search_settings.rerank_provider_type, rerank_api_key=search_settings.rerank_api_key, num_rerank=search_settings.num_rerank, + background_reindex_enabled=search_settings.background_reindex_enabled, ) db_session.add(embedding_model) @@ -305,6 +309,7 @@ def get_old_default_embedding_model() -> IndexingSetting: model_dim=( DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM ), + embedding_precision=(EmbeddingPrecision.FLOAT), normalize=( NORMALIZE_EMBEDDINGS if is_overridden @@ -322,6 +327,7 @@ def get_new_default_embedding_model() -> IndexingSetting: return IndexingSetting( model_name=DOCUMENT_ENCODER_MODEL, model_dim=DOC_EMBEDDING_DIM, + embedding_precision=(EmbeddingPrecision.FLOAT), normalize=NORMALIZE_EMBEDDINGS, query_prefix=ASYM_QUERY_PREFIX, passage_prefix=ASYM_PASSAGE_PREFIX, diff --git a/backend/onyx/db/swap_index.py b/backend/onyx/db/swap_index.py index abe7bdaf5..6e5472ea9 100644 --- a/backend/onyx/db/swap_index.py +++ b/backend/onyx/db/swap_index.py @@ -8,10 +8,12 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model from onyx.db.index_attempt import ( count_unique_cc_pairs_with_successful_index_attempts, ) +from onyx.db.models import ConnectorCredentialPair from onyx.db.models import SearchSettings from onyx.db.search_settings import get_current_search_settings from onyx.db.search_settings import get_secondary_search_settings from onyx.db.search_settings import update_search_settings_status +from onyx.document_index.factory import get_default_document_index from onyx.key_value_store.factory import get_kv_store from onyx.utils.logger import setup_logger @@ -19,7 +21,49 @@ from onyx.utils.logger import setup_logger logger = setup_logger() -def check_index_swap(db_session: Session) -> SearchSettings | None: +def _perform_index_swap( + db_session: Session, + current_search_settings: 
SearchSettings, + secondary_search_settings: SearchSettings, + all_cc_pairs: list[ConnectorCredentialPair], +) -> None: + """Swap the indices and expire the old one.""" + update_search_settings_status( + search_settings=current_search_settings, + new_status=IndexModelStatus.PAST, + db_session=db_session, + ) + + update_search_settings_status( + search_settings=secondary_search_settings, + new_status=IndexModelStatus.PRESENT, + db_session=db_session, + ) + + if len(all_cc_pairs) > 0: + kv_store = get_kv_store() + kv_store.store(KV_REINDEX_KEY, False) + + # Expire jobs for the now past index/embedding model + cancel_indexing_attempts_past_model(db_session) + + # Recount aggregates + for cc_pair in all_cc_pairs: + resync_cc_pair(cc_pair, db_session=db_session) + + # remove the old index from the vector db + document_index = get_default_document_index(secondary_search_settings, None) + document_index.ensure_indices_exist( + primary_embedding_dim=secondary_search_settings.final_embedding_dim, + primary_embedding_precision=secondary_search_settings.embedding_precision, + # just finished swap, no more secondary index + secondary_index_embedding_dim=None, + secondary_index_embedding_precision=None, + ) + + +def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None: """Get count of cc-pairs and count of successful index_attempts for the new model grouped by connector + credential, if it's the same, then assume new index is done building. If so, swap the indices and expire the old one. @@ -27,52 +71,45 @@ def check_index_swap(db_session: Session) -> SearchSettings | None: Returns None if search settings did not change, or the old search settings if they did change. """ - - old_search_settings = None - # Default CC-pair created for Ingestion API unused here all_cc_pairs = get_connector_credential_pairs(db_session) cc_pair_count = max(len(all_cc_pairs) - 1, 0) - search_settings = get_secondary_search_settings(db_session) + secondary_search_settings = get_secondary_search_settings(db_session) - if not search_settings: + if not secondary_search_settings: return None + # If the secondary search settings are not configured to reindex in the background, + # we can just swap over instantly + if not secondary_search_settings.background_reindex_enabled: + current_search_settings = get_current_search_settings(db_session) + _perform_index_swap( + db_session=db_session, + current_search_settings=current_search_settings, + secondary_search_settings=secondary_search_settings, + all_cc_pairs=all_cc_pairs, + ) + return current_search_settings + unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts( - search_settings_id=search_settings.id, db_session=db_session + search_settings_id=secondary_search_settings.id, db_session=db_session ) # Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this # function is correct. 
The unique_cc_indexings are specifically for the existing cc-pairs + old_search_settings = None if unique_cc_indexings > cc_pair_count: logger.error("More unique indexings than cc pairs, should not occur") if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings: # Swap indices current_search_settings = get_current_search_settings(db_session) - update_search_settings_status( - search_settings=current_search_settings, - new_status=IndexModelStatus.PAST, + _perform_index_swap( db_session=db_session, + current_search_settings=current_search_settings, + secondary_search_settings=secondary_search_settings, + all_cc_pairs=all_cc_pairs, ) - - update_search_settings_status( - search_settings=search_settings, - new_status=IndexModelStatus.PRESENT, - db_session=db_session, - ) - - if cc_pair_count > 0: - kv_store = get_kv_store() - kv_store.store(KV_REINDEX_KEY, False) - - # Expire jobs for the now past index/embedding model - cancel_indexing_attempts_past_model(db_session) - - # Recount aggregates - for cc_pair in all_cc_pairs: - resync_cc_pair(cc_pair, db_session=db_session) - - old_search_settings = current_search_settings + old_search_settings = current_search_settings return old_search_settings diff --git a/backend/onyx/document_index/interfaces.py b/backend/onyx/document_index/interfaces.py index 663e5feee..463abbc95 100644 --- a/backend/onyx/document_index/interfaces.py +++ b/backend/onyx/document_index/interfaces.py @@ -6,6 +6,7 @@ from typing import Any from onyx.access.models import DocumentAccess from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunkUncleaned +from onyx.db.enums import EmbeddingPrecision from onyx.indexing.models import DocMetadataAwareIndexChunk from shared_configs.model_server_models import Embedding @@ -145,17 +146,21 @@ class Verifiable(abc.ABC): @abc.abstractmethod def ensure_indices_exist( self, - index_embedding_dim: int, + primary_embedding_dim: int, + primary_embedding_precision: EmbeddingPrecision, secondary_index_embedding_dim: int | None, + secondary_index_embedding_precision: EmbeddingPrecision | None, ) -> None: """ Verify that the document index exists and is consistent with the expectations in the code. Parameters: - - index_embedding_dim: Vector dimensionality for the vector similarity part of the search + - primary_embedding_dim: Vector dimensionality for the vector similarity part of the search + - primary_embedding_precision: Precision of the vector similarity part of the search - secondary_index_embedding_dim: Vector dimensionality of the secondary index being built behind the scenes. The secondary index should only be built when switching embedding models therefore this dim should be different from the primary index. + - secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index """ raise NotImplementedError @@ -164,6 +169,7 @@ class Verifiable(abc.ABC): def register_multitenant_indices( indices: list[str], embedding_dims: list[int], + embedding_precisions: list[EmbeddingPrecision], ) -> None: """ Register multitenant indices with the document index. 
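For context on the Verifiable change above: dimension and precision now travel together for both the primary and the optional secondary index. A sketch of how a call site might drive the new signature, with `primary` and `secondary` standing in for SearchSettings rows (illustrative names, not code from this PR):

def verify_indices(document_index, primary, secondary=None) -> None:
    # final_embedding_dim already folds in any reduced_dimension override
    document_index.ensure_indices_exist(
        primary_embedding_dim=primary.final_embedding_dim,
        primary_embedding_precision=primary.embedding_precision,
        secondary_index_embedding_dim=(
            secondary.final_embedding_dim if secondary else None
        ),
        secondary_index_embedding_precision=(
            secondary.embedding_precision if secondary else None
        ),
    )
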
diff --git a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd index 2fd861b77..f846c32fc 100644 --- a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME { summary: dynamic } # Title embedding (x1) - field title_embedding type tensor<float>(x[VARIABLE_DIM]) { + field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) { indexing: attribute | index attribute { distance-metric: angular @@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME { } # Content embeddings (chunk + optional mini chunks embeddings) # "t" and "x" are arbitrary names, not special keywords - field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) { + field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) { indexing: attribute | index attribute { distance-metric: angular diff --git a/backend/onyx/document_index/vespa/app_config/validation-overrides.xml b/backend/onyx/document_index/vespa/app_config/validation-overrides.xml index c5d1598bf..7b0709620 100644 --- a/backend/onyx/document_index/vespa/app_config/validation-overrides.xml +++ b/backend/onyx/document_index/vespa/app_config/validation-overrides.xml @@ -5,4 +5,7 @@ indexing-change + field-type-change diff --git a/backend/onyx/document_index/vespa/chunk_retrieval.py b/backend/onyx/document_index/vespa/chunk_retrieval.py index 37225b452..5f3dff5c8 100644 --- a/backend/onyx/document_index/vespa/chunk_retrieval.py +++ b/backend/onyx/document_index/vespa/chunk_retrieval.py @@ -310,6 +310,11 @@ def query_vespa( f"Request Headers: {e.request.headers}\n" f"Request Payload: {params}\n" f"Exception: {str(e)}" + + ( + f"\nResponse: {e.response.text}" + if isinstance(e, httpx.HTTPStatusError) + else "" + ) ) raise httpx.HTTPError(error_base) from e diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index c2e631f6c..17aadb36c 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -26,6 +26,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS from onyx.configs.constants import KV_REINDEX_KEY from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunkUncleaned +from onyx.db.enums import EmbeddingPrecision from onyx.document_index.document_index_utils import get_document_chunk_ids from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.interfaces import DocumentInsertionRecord @@ -63,6 +64,7 @@ from onyx.document_index.vespa_constants import DATE_REPLACEMENT from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT from onyx.document_index.vespa_constants import DOCUMENT_SETS +from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT from onyx.document_index.vespa_constants import HIDDEN from onyx.document_index.vespa_constants import NUM_THREADS from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT @@ -112,6 +114,21 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str: return "\n".join(doc_lines) +def _replace_template_values_in_schema( + schema_template: str, + index_name: str, + embedding_dim: int, + embedding_precision: EmbeddingPrecision, +) -> str: + return ( + schema_template.replace( + EMBEDDING_PRECISION_REPLACEMENT_PAT, 
embedding_precision.value + ) + .replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name) + .replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim)) + ) + + def add_ngrams_to_schema(schema_content: str) -> str: # Add the match blocks containing gram and gram-size to title and content fields schema_content = re.sub( @@ -163,8 +180,10 @@ class VespaIndex(DocumentIndex): def ensure_indices_exist( self, - index_embedding_dim: int, + primary_embedding_dim: int, + primary_embedding_precision: EmbeddingPrecision, secondary_index_embedding_dim: int | None, + secondary_index_embedding_precision: EmbeddingPrecision | None, ) -> None: if MULTI_TENANT: logger.info( @@ -221,18 +240,29 @@ class VespaIndex(DocumentIndex): schema_template = schema_f.read() schema_template = schema_template.replace(TENANT_ID_PAT, "") - schema = schema_template.replace( - DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name - ).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim)) + schema = _replace_template_values_in_schema( + schema_template, + self.index_name, + primary_embedding_dim, + primary_embedding_precision, + ) schema = add_ngrams_to_schema(schema) if needs_reindexing else schema schema = schema.replace(TENANT_ID_PAT, "") zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8") if self.secondary_index_name: - upcoming_schema = schema_template.replace( - DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name - ).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim)) + if secondary_index_embedding_dim is None: + raise ValueError("Secondary index embedding dimension is required") + if secondary_index_embedding_precision is None: + raise ValueError("Secondary index embedding precision is required") + + upcoming_schema = _replace_template_values_in_schema( + schema_template, + self.secondary_index_name, + secondary_index_embedding_dim, + secondary_index_embedding_precision, + ) zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8") zip_file = in_memory_zip_from_file_bytes(zip_dict) @@ -251,6 +281,7 @@ class VespaIndex(DocumentIndex): def register_multitenant_indices( indices: list[str], embedding_dims: list[int], + embedding_precisions: list[EmbeddingPrecision], ) -> None: if not MULTI_TENANT: raise ValueError("Multi-tenant is not enabled") @@ -309,13 +340,14 @@ class VespaIndex(DocumentIndex): for i, index_name in enumerate(indices): embedding_dim = embedding_dims[i] + embedding_precision = embedding_precisions[i] logger.info( f"Creating index: {index_name} with embedding dimension: {embedding_dim}" ) - schema = schema_template.replace( - DANSWER_CHUNK_REPLACEMENT_PAT, index_name - ).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim)) + schema = _replace_template_values_in_schema( + schema_template, index_name, embedding_dim, embedding_precision + ) schema = schema.replace( TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "" ) diff --git a/backend/onyx/document_index/vespa_constants.py b/backend/onyx/document_index/vespa_constants.py index a259aede5..82bb59198 100644 --- a/backend/onyx/document_index/vespa_constants.py +++ b/backend/onyx/document_index/vespa_constants.py @@ -6,6 +6,7 @@ from onyx.configs.app_configs import VESPA_TENANT_PORT from onyx.configs.constants import SOURCE_TYPE VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM" +EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION" DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME" DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT" SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER" diff --git 
a/backend/onyx/indexing/embedder.py b/backend/onyx/indexing/embedder.py index a692827c5..67bf56fc8 100644 --- a/backend/onyx/indexing/embedder.py +++ b/backend/onyx/indexing/embedder.py @@ -38,6 +38,7 @@ class IndexingEmbedder(ABC): api_url: str | None, api_version: str | None, deployment_name: str | None, + reduced_dimension: int | None, callback: IndexingHeartbeatInterface | None, ): self.model_name = model_name @@ -60,6 +61,7 @@ class IndexingEmbedder(ABC): api_url=api_url, api_version=api_version, deployment_name=deployment_name, + reduced_dimension=reduced_dimension, # The below are globally set, this flow always uses the indexing one server_host=INDEXING_MODEL_SERVER_HOST, server_port=INDEXING_MODEL_SERVER_PORT, @@ -87,6 +89,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): api_url: str | None = None, api_version: str | None = None, deployment_name: str | None = None, + reduced_dimension: int | None = None, callback: IndexingHeartbeatInterface | None = None, ): super().__init__( @@ -99,6 +102,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): api_url, api_version, deployment_name, + reduced_dimension, callback, ) @@ -219,6 +223,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder): api_url=search_settings.api_url, api_version=search_settings.api_version, deployment_name=search_settings.deployment_name, + reduced_dimension=search_settings.reduced_dimension, callback=callback, ) diff --git a/backend/onyx/indexing/models.py b/backend/onyx/indexing/models.py index 0c4451cc7..cffbdaa9b 100644 --- a/backend/onyx/indexing/models.py +++ b/backend/onyx/indexing/models.py @@ -5,6 +5,7 @@ from pydantic import Field from onyx.access.models import DocumentAccess from onyx.connectors.models import Document +from onyx.db.enums import EmbeddingPrecision from onyx.utils.logger import setup_logger from shared_configs.enums import EmbeddingProvider from shared_configs.model_server_models import Embedding @@ -143,10 +144,20 @@ class IndexingSetting(EmbeddingModelDetail): model_dim: int index_name: str | None multipass_indexing: bool + embedding_precision: EmbeddingPrecision + reduced_dimension: int | None = None + + background_reindex_enabled: bool = True # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} + @property + def final_embedding_dim(self) -> int: + if self.reduced_dimension: + return self.reduced_dimension + return self.model_dim + @classmethod def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting": return cls( @@ -158,6 +169,9 @@ class IndexingSetting(EmbeddingModelDetail): provider_type=search_settings.provider_type, index_name=search_settings.index_name, multipass_indexing=search_settings.multipass_indexing, + embedding_precision=search_settings.embedding_precision, + reduced_dimension=search_settings.reduced_dimension, + background_reindex_enabled=search_settings.background_reindex_enabled, ) diff --git a/backend/onyx/key_value_store/store.py b/backend/onyx/key_value_store/store.py index ab612d6a7..75f54a111 100644 --- a/backend/onyx/key_value_store/store.py +++ b/backend/onyx/key_value_store/store.py @@ -19,6 +19,7 @@ from onyx.utils.special_types import JSON_ro from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id + logger = setup_logger() diff --git a/backend/onyx/main.py b/backend/onyx/main.py index 2444e6f19..003e26fb2 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as 
cc_pair_router from onyx.server.documents.connector import router as connector_router from onyx.server.documents.credential import router as credential_router from onyx.server.documents.document import router as document_router -from onyx.server.documents.standard_oauth import router as oauth_router from onyx.server.features.document_set.api import router as document_set_router from onyx.server.features.folder.api import router as folder_router from onyx.server.features.input_prompt.api import ( @@ -323,7 +322,6 @@ def get_application() -> FastAPI: ) include_router_with_global_prefix_prepended(application, long_term_logs_router) include_router_with_global_prefix_prepended(application, api_key_router) - include_router_with_global_prefix_prepended(application, oauth_router) if AUTH_TYPE == AuthType.DISABLED: # Server logs this during auth setup verification step diff --git a/backend/onyx/natural_language_processing/search_nlp_models.py b/backend/onyx/natural_language_processing/search_nlp_models.py index 5f1f2d59a..3a7fcdf6f 100644 --- a/backend/onyx/natural_language_processing/search_nlp_models.py +++ b/backend/onyx/natural_language_processing/search_nlp_models.py @@ -89,6 +89,7 @@ class EmbeddingModel: callback: IndexingHeartbeatInterface | None = None, api_version: str | None = None, deployment_name: str | None = None, + reduced_dimension: int | None = None, ) -> None: self.api_key = api_key self.provider_type = provider_type @@ -100,6 +101,7 @@ class EmbeddingModel: self.api_url = api_url self.api_version = api_version self.deployment_name = deployment_name + self.reduced_dimension = reduced_dimension self.tokenizer = get_tokenizer( model_name=model_name, provider_type=provider_type ) @@ -188,6 +190,7 @@ class EmbeddingModel: manual_query_prefix=self.query_prefix, manual_passage_prefix=self.passage_prefix, api_url=self.api_url, + reduced_dimension=self.reduced_dimension, ) start_time = time.time() @@ -300,6 +303,7 @@ class EmbeddingModel: retrim_content=retrim_content, api_version=search_settings.api_version, deployment_name=search_settings.deployment_name, + reduced_dimension=search_settings.reduced_dimension, ) diff --git a/backend/onyx/onyxbot/slack/blocks.py b/backend/onyx/onyxbot/slack/blocks.py index a096549cc..fddf7e4be 100644 --- a/backend/onyx/onyxbot/slack/blocks.py +++ b/backend/onyx/onyxbot/slack/blocks.py @@ -31,12 +31,18 @@ from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID +from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID +from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID from onyx.onyxbot.slack.formatting import format_slack_message from onyx.onyxbot.slack.icons import source_to_github_img_link +from onyx.onyxbot.slack.models import ActionValuesEphemeralMessage +from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageChannelConfig +from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageMessageInfo from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.utils import build_continue_in_web_ui_id from onyx.onyxbot.slack.utils import build_feedback_id +from onyx.onyxbot.slack.utils import build_publish_ephemeral_message_id from onyx.onyxbot.slack.utils import remove_slack_text_interactions from 
onyx.onyxbot.slack.utils import translate_vespa_highlight_to_slack from onyx.utils.text_processing import decode_escapes @@ -105,6 +111,77 @@ def _build_qa_feedback_block( ) +def _build_ephemeral_publication_block( + channel_id: str, + chat_message_id: int, + message_info: SlackMessageInfo, + original_question_ts: str, + channel_conf: ChannelConfig, + feedback_reminder_id: str | None = None, +) -> Block: + # check whether the message is in a thread + if ( + message_info is not None + and message_info.msg_to_respond is not None + and message_info.thread_to_respond is not None + and (message_info.msg_to_respond == message_info.thread_to_respond) + ): + respond_ts = None + else: + respond_ts = original_question_ts + + action_values_ephemeral_message_channel_config = ( + ActionValuesEphemeralMessageChannelConfig( + channel_name=channel_conf.get("channel_name"), + respond_tag_only=channel_conf.get("respond_tag_only"), + respond_to_bots=channel_conf.get("respond_to_bots"), + is_ephemeral=channel_conf.get("is_ephemeral", False), + respond_member_group_list=channel_conf.get("respond_member_group_list"), + answer_filters=channel_conf.get("answer_filters"), + follow_up_tags=channel_conf.get("follow_up_tags"), + show_continue_in_web_ui=channel_conf.get("show_continue_in_web_ui", False), + ) + ) + + action_values_ephemeral_message_message_info = ( + ActionValuesEphemeralMessageMessageInfo( + bypass_filters=message_info.bypass_filters, + channel_to_respond=message_info.channel_to_respond, + msg_to_respond=message_info.msg_to_respond, + email=message_info.email, + sender_id=message_info.sender_id, + thread_messages=[], + is_bot_msg=message_info.is_bot_msg, + is_bot_dm=message_info.is_bot_dm, + thread_to_respond=respond_ts, + ) + ) + + action_values_ephemeral_message = ActionValuesEphemeralMessage( + original_question_ts=original_question_ts, + feedback_reminder_id=feedback_reminder_id, + chat_message_id=chat_message_id, + message_info=action_values_ephemeral_message_message_info, + channel_conf=action_values_ephemeral_message_channel_config, + ) + + return ActionsBlock( + block_id=build_publish_ephemeral_message_id(original_question_ts), + elements=[ + ButtonElement( + action_id=SHOW_EVERYONE_ACTION_ID, + text="📢 Share with Everyone", + value=action_values_ephemeral_message.model_dump_json(), + ), + ButtonElement( + action_id=KEEP_TO_YOURSELF_ACTION_ID, + text="🤫 Keep to Yourself", + value=action_values_ephemeral_message.model_dump_json(), + ), + ], + ) + + def get_document_feedback_blocks() -> Block: return SectionBlock( text=( @@ -486,16 +563,21 @@ def build_slack_response_blocks( use_citations: bool, feedback_reminder_id: str | None, skip_ai_feedback: bool = False, + offer_ephemeral_publication: bool = False, expecting_search_result: bool = False, + skip_restated_question: bool = False, ) -> list[Block]: """ This function is a top level function that builds all the blocks for the Slack response. It also handles combining all the blocks together. 
""" # If called with the OnyxBot slash command, the question is lost so we have to reshow it - restate_question_block = get_restate_blocks( - message_info.thread_messages[-1].message, message_info.is_bot_msg - ) + if not skip_restated_question: + restate_question_block = get_restate_blocks( + message_info.thread_messages[-1].message, message_info.is_bot_msg + ) + else: + restate_question_block = [] if expecting_search_result: answer_blocks = _build_qa_response_blocks( @@ -520,12 +602,36 @@ def build_slack_response_blocks( ) follow_up_block = [] - if channel_conf and channel_conf.get("follow_up_tags") is not None: + if ( + channel_conf + and channel_conf.get("follow_up_tags") is not None + and not channel_conf.get("is_ephemeral", False) + ): follow_up_block.append( _build_follow_up_block(message_id=answer.chat_message_id) ) - ai_feedback_block = [] + publish_ephemeral_message_block = [] + + if ( + offer_ephemeral_publication + and answer.chat_message_id is not None + and message_info.msg_to_respond is not None + and channel_conf is not None + ): + publish_ephemeral_message_block.append( + _build_ephemeral_publication_block( + channel_id=message_info.channel_to_respond, + chat_message_id=answer.chat_message_id, + original_question_ts=message_info.msg_to_respond, + message_info=message_info, + channel_conf=channel_conf, + feedback_reminder_id=feedback_reminder_id, + ) + ) + + ai_feedback_block: list[Block] = [] + if answer.chat_message_id is not None and not skip_ai_feedback: ai_feedback_block.append( _build_qa_feedback_block( @@ -547,6 +653,7 @@ def build_slack_response_blocks( all_blocks = ( restate_question_block + answer_blocks + + publish_ephemeral_message_block + ai_feedback_block + citations_divider + citations_blocks diff --git a/backend/onyx/onyxbot/slack/constants.py b/backend/onyx/onyxbot/slack/constants.py index 6a5b3ed43..1f2d4ed68 100644 --- a/backend/onyx/onyxbot/slack/constants.py +++ b/backend/onyx/onyxbot/slack/constants.py @@ -2,6 +2,8 @@ from enum import Enum LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" +SHOW_EVERYONE_ACTION_ID = "show-everyone" +KEEP_TO_YOURSELF_ACTION_ID = "keep-to-yourself" CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui" FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button" IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button" diff --git a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py index 8a81e7fd3..e78a4cc08 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_buttons.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_buttons.py @@ -1,3 +1,4 @@ +import json from typing import Any from typing import cast @@ -5,21 +6,32 @@ from slack_sdk import WebClient from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.views import View from slack_sdk.socket_mode.request import SocketModeRequest +from slack_sdk.webhook import WebhookClient +from onyx.chat.models import ChatOnyxBotResponse +from onyx.chat.models import CitationInfo +from onyx.chat.models import QADocsResponse from onyx.configs.constants import MessageType from onyx.configs.constants import SearchFeedbackType from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI from onyx.connectors.slack.utils import expert_info_from_slack_id from onyx.connectors.slack.utils import make_slack_api_rate_limited +from onyx.context.search.models import SavedSearchDoc +from onyx.db.chat import get_chat_message +from onyx.db.chat import 
translate_db_message_to_chat_message_detail from onyx.db.feedback import create_chat_message_feedback from onyx.db.feedback import create_doc_retrieval_feedback from onyx.db.session import get_session_with_current_tenant +from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks +from onyx.onyxbot.slack.blocks import build_slack_response_blocks from onyx.onyxbot.slack.blocks import get_document_feedback_blocks from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FeedbackVisibility +from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID +from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID from onyx.onyxbot.slack.handlers.handle_message import ( remove_scheduled_feedback_reminder, @@ -35,15 +47,48 @@ from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails from onyx.onyxbot.slack.utils import get_channel_name_from_id from onyx.onyxbot.slack.utils import get_feedback_visibility from onyx.onyxbot.slack.utils import read_slack_thread -from onyx.onyxbot.slack.utils import respond_in_thread +from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import TenantSocketModeClient from onyx.onyxbot.slack.utils import update_emote_react +from onyx.server.query_and_chat.models import ChatMessageDetail from onyx.utils.logger import setup_logger logger = setup_logger() +def _convert_db_doc_id_to_document_ids( + citation_dict: dict[int, int], top_documents: list[SavedSearchDoc] +) -> list[CitationInfo]: + citation_list_with_document_id = [] + for citation_num, db_doc_id in citation_dict.items(): + if db_doc_id is not None: + matching_doc = next( + (d for d in top_documents if d.db_doc_id == db_doc_id), None + ) + if matching_doc: + citation_list_with_document_id.append( + CitationInfo( + citation_num=citation_num, document_id=matching_doc.document_id + ) + ) + return citation_list_with_document_id + + +def _build_citation_list(chat_message_detail: ChatMessageDetail) -> list[CitationInfo]: + citation_dict = chat_message_detail.citations + if citation_dict is None: + return [] + else: + top_documents = ( + chat_message_detail.context_docs.top_documents + if chat_message_detail.context_docs + else [] + ) + citation_list = _convert_db_doc_id_to_document_ids(citation_dict, top_documents) + return citation_list + + def handle_doc_feedback_button( req: SocketModeRequest, client: TenantSocketModeClient, @@ -58,7 +103,7 @@ def handle_doc_feedback_button( external_id = build_feedback_id(query_event_id, doc_id, doc_rank) channel_id = req.payload["container"]["channel_id"] - thread_ts = req.payload["container"]["thread_ts"] + thread_ts = req.payload["container"].get("thread_ts", None) data = View( type="modal", @@ -84,7 +129,7 @@ def handle_generate_answer_button( channel_id = req.payload["channel"]["id"] channel_name = req.payload["channel"]["name"] message_ts = req.payload["message"]["ts"] - thread_ts = req.payload["container"]["thread_ts"] + thread_ts = req.payload["container"].get("thread_ts", None) user_id = req.payload["user"]["id"] expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={}) email = expert_info.email if expert_info else None @@ -106,7 +151,7 @@ def handle_generate_answer_button( # tell 
the user that we're working on it # Send an ephemeral message to the user that we're generating the answer - respond_in_thread( + respond_in_thread_or_channel( client=client.web_client, channel=channel_id, receiver_ids=[user_id], @@ -142,6 +187,178 @@ def handle_generate_answer_button( ) +def handle_publish_ephemeral_message_button( + req: SocketModeRequest, + client: TenantSocketModeClient, + action_id: str, +) -> None: + """ + This function handles the Share with Everyone/Keep for Yourself buttons + for ephemeral messages. + """ + channel_id = req.payload["channel"]["id"] + ephemeral_message_ts = req.payload["container"]["message_ts"] + + slack_sender_id = req.payload["user"]["id"] + response_url = req.payload["response_url"] + webhook = WebhookClient(url=response_url) + + # The additional data that was added to the buttons. + # Specifically, this contains the message_info, channel_conf information + # and some additional attributes. + value_dict = json.loads(req.payload["actions"][0]["value"]) + + original_question_ts = value_dict.get("original_question_ts") + if not original_question_ts: + raise ValueError("Missing original_question_ts in the payload") + if not ephemeral_message_ts: + raise ValueError("Missing ephemeral_message_ts in the payload") + + feedback_reminder_id = value_dict.get("feedback_reminder_id") + + slack_message_info = SlackMessageInfo(**value_dict["message_info"]) + channel_conf = value_dict.get("channel_conf") + + user_email = value_dict.get("message_info", {}).get("email") + + chat_message_id = value_dict.get("chat_message_id") + + # Obtain onyx_user and chat_message information + if not chat_message_id: + raise ValueError("Missing chat_message_id in the payload") + + with get_session_with_current_tenant() as db_session: + onyx_user = get_user_by_email(user_email, db_session) + if not onyx_user: + raise ValueError("Cannot determine onyx_user_id from email in payload") + try: + chat_message = get_chat_message(chat_message_id, onyx_user.id, db_session) + except ValueError: + chat_message = get_chat_message( + chat_message_id, None, db_session + ) # is this a good idea? + except Exception as e: + logger.error(f"Failed to get chat message: {e}") + raise e + + chat_message_detail = translate_db_message_to_chat_message_detail(chat_message) + + # Construct the citations in the proper format and then the answer in the + # format needed to build the response blocks. + citation_list = _build_citation_list(chat_message_detail) + + onyx_bot_answer = ChatOnyxBotResponse( + answer=chat_message_detail.message, + citations=citation_list, + chat_message_id=chat_message_id, + docs=QADocsResponse( + top_documents=chat_message_detail.context_docs.top_documents + if chat_message_detail.context_docs + else [], + predicted_flow=None, + predicted_search=None, + applied_source_filters=None, + applied_time_cutoff=None, + recency_bias_multiplier=1.0, + ), + llm_selected_doc_indices=None, + error_msg=None, + ) + + # Note: we need to use the webhook and the response_url to update/delete ephemeral messages + if action_id == SHOW_EVERYONE_ACTION_ID: + # Convert to non-ephemeral message in thread + try: + webhook.send( + response_type="ephemeral", + text="", + blocks=[], + replace_original=True, + delete_original=True, + ) + except Exception as e: + logger.error(f"Failed to send webhook: {e}") + + # Remove handling of the ephemeral block and add AI feedback. 
+ all_blocks = build_slack_response_blocks( + answer=onyx_bot_answer, + message_info=slack_message_info, + channel_conf=channel_conf, + use_citations=True, + feedback_reminder_id=feedback_reminder_id, + skip_ai_feedback=False, + offer_ephemeral_publication=False, + skip_restated_question=True, + ) + try: + # Post in thread as non-ephemeral message + respond_in_thread_or_channel( + client=client.web_client, + channel=channel_id, + receiver_ids=None, # If respond_member_group_list is set, send to them. TODO: check! + text="Hello! Onyx has some results for you!", + blocks=all_blocks, + thread_ts=original_question_ts, + # don't unfurl, since otherwise we will have 5+ previews which makes the message very long + unfurl=False, + send_as_ephemeral=False, + ) + except Exception as e: + logger.error(f"Failed to publish ephemeral message: {e}") + raise e + + elif action_id == KEEP_TO_YOURSELF_ACTION_ID: + # Keep as ephemeral message in channel or thread, but remove the publish button and add feedback button + + changed_blocks = build_slack_response_blocks( + answer=onyx_bot_answer, + message_info=slack_message_info, + channel_conf=channel_conf, + use_citations=True, + feedback_reminder_id=feedback_reminder_id, + skip_ai_feedback=False, + offer_ephemeral_publication=False, + skip_restated_question=True, + ) + + try: + if slack_message_info.thread_to_respond is not None: + # There seems to be a bug in slack where an update within the thread + # actually leads to the update to be posted in the channel. Therefore, + # for now we delete the original ephemeral message and post a new one + # if the ephemeral message is in a thread. + webhook.send( + response_type="ephemeral", + text="", + blocks=[], + replace_original=True, + delete_original=True, + ) + + respond_in_thread_or_channel( + client=client.web_client, + channel=channel_id, + receiver_ids=[slack_sender_id], + text="Your personal response, sent as an ephemeral message.", + blocks=changed_blocks, + thread_ts=original_question_ts, + # don't unfurl, since otherwise we will have 5+ previews which makes the message very long + unfurl=False, + send_as_ephemeral=True, + ) + else: + # This works fine if the ephemeral message is in the channel + webhook.send( + response_type="ephemeral", + text="Your personal response, sent as an ephemeral message.", + blocks=changed_blocks, + replace_original=True, + delete_original=False, + ) + except Exception as e: + logger.error(f"Failed to send webhook: {e}") + + def handle_slack_feedback( feedback_id: str, feedback_type: str, @@ -153,13 +370,20 @@ def handle_slack_feedback( ) -> None: message_id, doc_id, doc_rank = decompose_action_id(feedback_id) + # Get Onyx user from Slack ID + expert_info = expert_info_from_slack_id( + user_id_to_post_confirmation, client, user_cache={} + ) + email = expert_info.email if expert_info else None + with get_session_with_current_tenant() as db_session: + onyx_user = get_user_by_email(email, db_session) if email else None if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]: create_chat_message_feedback( is_positive=feedback_type == LIKE_BLOCK_ACTION_ID, feedback_text="", chat_message_id=message_id, - user_id=None, # no "user" for Slack bot for now + user_id=onyx_user.id if onyx_user else None, db_session=db_session, ) remove_scheduled_feedback_reminder( @@ -213,7 +437,7 @@ def handle_slack_feedback( else: msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer" - respond_in_thread( + respond_in_thread_or_channel( client=client, 
channel=channel_id_to_post_confirmation, text=msg, @@ -232,7 +456,7 @@ def handle_followup_button( action_id = cast(str, action.get("block_id")) channel_id = req.payload["container"]["channel_id"] - thread_ts = req.payload["container"]["thread_ts"] + thread_ts = req.payload["container"].get("thread_ts", None) update_emote_react( emoji=DANSWER_FOLLOWUP_EMOJI, @@ -265,7 +489,7 @@ def handle_followup_button( blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids) - respond_in_thread( + respond_in_thread_or_channel( client=client.web_client, channel=channel_id, text="Received your request for more help", @@ -315,7 +539,7 @@ def handle_followup_resolved_button( ) -> None: channel_id = req.payload["container"]["channel_id"] message_ts = req.payload["container"]["message_ts"] - thread_ts = req.payload["container"]["thread_ts"] + thread_ts = req.payload["container"].get("thread_ts", None) clicker_name = get_clicker_name(req, client) @@ -349,7 +573,7 @@ def handle_followup_resolved_button( resolved_block = SectionBlock(text=msg_text) - respond_in_thread( + respond_in_thread_or_channel( client=client.web_client, channel=channel_id, text="Your request for help has been addressed!", diff --git a/backend/onyx/onyxbot/slack/handlers/handle_message.py b/backend/onyx/onyxbot/slack/handlers/handle_message.py index 975828b9c..d6d87cefe 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_message.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_message.py @@ -18,7 +18,7 @@ from onyx.onyxbot.slack.handlers.handle_standard_answers import ( from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails from onyx.onyxbot.slack.utils import fetch_user_ids_from_groups -from onyx.onyxbot.slack.utils import respond_in_thread +from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import slack_usage_report from onyx.onyxbot.slack.utils import update_emote_react from onyx.utils.logger import setup_logger @@ -29,7 +29,7 @@ logger_base = setup_logger() def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None: if details.is_bot_msg and details.sender_id: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=details.channel_to_respond, thread_ts=details.msg_to_respond, @@ -202,7 +202,7 @@ def handle_message( # which would just respond to the sender if send_to and is_bot_msg: if sender_id: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=[sender_id], @@ -220,6 +220,7 @@ def handle_message( add_slack_user_if_not_exists(db_session, message_info.email) # first check if we need to respond with a standard answer + # standard answers should be published in a thread used_standard_answer = handle_standard_answers( message_info=message_info, receiver_ids=send_to, diff --git a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py index af6551d61..303bb8544 100644 --- a/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py +++ b/backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py @@ -33,7 +33,7 @@ from onyx.onyxbot.slack.blocks import build_slack_response_blocks from onyx.onyxbot.slack.handlers.utils import send_team_member_message from onyx.onyxbot.slack.handlers.utils import slackify_message_thread from onyx.onyxbot.slack.models import SlackMessageInfo -from onyx.onyxbot.slack.utils import respond_in_thread +from 
onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import SlackRateLimiter from onyx.onyxbot.slack.utils import update_emote_react from onyx.server.query_and_chat.models import CreateChatMessageRequest @@ -82,12 +82,38 @@ def handle_regular_answer( message_ts_to_respond_to = message_info.msg_to_respond is_bot_msg = message_info.is_bot_msg + + # Capture whether the channel's response mode is ephemeral. Even if the channel is set + # to respond with an ephemeral message, we still send as non-ephemeral if + # the message is a DM with the Onyx bot. + send_as_ephemeral = ( + slack_channel_config.channel_config.get("is_ephemeral", False) + and not message_info.is_bot_dm + ) + + # If the channel is configured to respond with an ephemeral message, + # or the message is a DM to the Onyx bot, we should use the proper Onyx user from the email. + # This will make documents privately accessible to the user available to Onyx Bot answers. + # Otherwise - if not ephemeral or DM to Onyx Bot - we must use None as the user to restrict + # to public docs. + user = None - if message_info.is_bot_dm: + if message_info.is_bot_dm or send_as_ephemeral: if message_info.email: with get_session_with_current_tenant() as db_session: user = get_user_by_email(message_info.email, db_session) + target_thread_ts = ( + None + if send_as_ephemeral and len(message_info.thread_messages) < 2 + else message_ts_to_respond_to + ) + target_receiver_ids = ( + [message_info.sender_id] + if message_info.sender_id and send_as_ephemeral + else receiver_ids + ) + document_set_names: list[str] | None = None prompt = None # If no persona is specified, use the default search based persona @@ -134,11 +160,10 @@ def handle_regular_answer( history_messages = messages[:-1] single_message_history = slackify_message_thread(history_messages) or None + # Always check for ACL permissions, also for document sets that were explicitly added + # to the Bot by the Administrator. (Change relative to earlier behavior where all documents + # in an attached document set were available to all users in the channel.) bypass_acl = False - if slack_channel_config.persona and slack_channel_config.persona.document_sets: - # For Slack channels, use the full document set, admin will be warned when configuring it - # with non-public document sets - bypass_acl = True if not message_ts_to_respond_to and not is_bot_msg: # if the message is not "/onyx" command, then it should have a message ts to respond to @@ -219,12 +244,13 @@ def handle_regular_answer( # Optionally, respond in thread with the error message. Used primarily # for debugging purposes if should_respond_with_error_msgs: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, - receiver_ids=None, + receiver_ids=target_receiver_ids, text=f"Encountered exception when trying to answer: \n\n```{e}```", - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, + send_as_ephemeral=send_as_ephemeral, ) # In case of failures, don't keep the reaction there permanently @@ -242,32 +268,36 @@ def handle_regular_answer( if answer is None: assert DISABLE_GENERATIVE_AI is True try: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, - receiver_ids=receiver_ids, + receiver_ids=target_receiver_ids, text="Hello! Onyx has some results for you!", blocks=[ SectionBlock( text="Onyx is down for maintenance.\nWe're working hard on recharging the AI!" 
) ], - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, + send_as_ephemeral=send_as_ephemeral, # don't unfurl, since otherwise we will have 5+ previews which makes the message very long unfurl=False, ) # For DM (ephemeral message), we need to create a thread via a normal message so the user can see # the ephemeral message. This also will give the user a notification which ephemeral message does not. - if receiver_ids: - respond_in_thread( + + # If the channel is ephemeral, we don't need to send a message to the user since they will already see the message + if target_receiver_ids and not send_as_ephemeral: + respond_in_thread_or_channel( client=client, channel=channel, text=( "👋 Hi, we've just gathered and forwarded the relevant " + "information to the team. They'll get back to you shortly!" ), - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, + send_as_ephemeral=send_as_ephemeral, ) return False @@ -316,12 +346,13 @@ def handle_regular_answer( # Optionally, respond in thread with the error message # Used primarily for debugging purposes if should_respond_with_error_msgs: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, - receiver_ids=None, + receiver_ids=target_receiver_ids, text="Found no documents when trying to answer. Did you index any documents?", - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, + send_as_ephemeral=send_as_ephemeral, ) return True @@ -349,15 +380,27 @@ def handle_regular_answer( # Optionally, respond in thread with the error message # Used primarily for debugging purposes if should_respond_with_error_msgs: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, - receiver_ids=None, + receiver_ids=target_receiver_ids, text="Found no citations or quotes when trying to answer.", - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, + send_as_ephemeral=send_as_ephemeral, ) return True + if ( + send_as_ephemeral + and target_receiver_ids is not None + and len(target_receiver_ids) == 1 + ): + offer_ephemeral_publication = True + skip_ai_feedback = True + else: + offer_ephemeral_publication = False + skip_ai_feedback = False if feedback_reminder_id else True + all_blocks = build_slack_response_blocks( message_info=message_info, answer=answer, @@ -365,31 +408,39 @@ def handle_regular_answer( use_citations=True, # No longer supporting quotes feedback_reminder_id=feedback_reminder_id, expecting_search_result=expecting_search_result, + offer_ephemeral_publication=offer_ephemeral_publication, + skip_ai_feedback=skip_ai_feedback, ) try: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, - receiver_ids=[message_info.sender_id] - if message_info.is_bot_msg and message_info.sender_id - else receiver_ids, + receiver_ids=target_receiver_ids, text="Hello! Onyx has some results for you!", blocks=all_blocks, - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, # don't unfurl, since otherwise we will have 5+ previews which makes the message very long unfurl=False, + send_as_ephemeral=send_as_ephemeral, ) # For DM (ephemeral message), we need to create a thread via a normal message so the user can see # the ephemeral message. This also will give the user a notification which ephemeral message does not. 
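# A minimal, self-contained sketch of the delivery rules introduced above
# (hypothetical helper, not part of this change; the real handler computes
# send_as_ephemeral / target_thread_ts / target_receiver_ids inline):
def _resolve_delivery(
    is_ephemeral_channel: bool,
    is_bot_dm: bool,
    sender_id: str | None,
    thread_message_count: int,
    msg_ts: str | None,
    receiver_ids: list[str] | None,
) -> tuple[bool, str | None, list[str] | None]:
    # DMs with the bot are always answered non-ephemerally
    send_as_ephemeral = is_ephemeral_channel and not is_bot_dm
    # a fresh ephemeral answer goes straight to the channel; replies inside
    # an existing thread keep the thread timestamp
    thread_ts = None if send_as_ephemeral and thread_message_count < 2 else msg_ts
    # ephemeral answers are addressed only to the asker
    targets = [sender_id] if sender_id and send_as_ephemeral else receiver_ids
    return send_as_ephemeral, thread_ts, targets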
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /onyx message # so we shouldn't send_team_member_message - if receiver_ids and message_ts_to_respond_to is not None: + if ( + target_receiver_ids + and message_ts_to_respond_to is not None + and not send_as_ephemeral + and target_thread_ts is not None + ): send_team_member_message( client=client, channel=channel, - thread_ts=message_ts_to_respond_to, + thread_ts=target_thread_ts, + receiver_ids=target_receiver_ids, + send_as_ephemeral=send_as_ephemeral, ) return False diff --git a/backend/onyx/onyxbot/slack/handlers/utils.py b/backend/onyx/onyxbot/slack/handlers/utils.py index ea8ab3288..83835e87b 100644 --- a/backend/onyx/onyxbot/slack/handlers/utils.py +++ b/backend/onyx/onyxbot/slack/handlers/utils.py @@ -2,7 +2,7 @@ from slack_sdk import WebClient from onyx.chat.models import ThreadMessage from onyx.configs.constants import MessageType -from onyx.onyxbot.slack.utils import respond_in_thread +from onyx.onyxbot.slack.utils import respond_in_thread_or_channel def slackify_message_thread(messages: list[ThreadMessage]) -> str: @@ -32,8 +32,10 @@ def send_team_member_message( client: WebClient, channel: str, thread_ts: str, + receiver_ids: list[str] | None = None, + send_as_ephemeral: bool = False, ) -> None: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, text=( @@ -41,4 +43,6 @@ def send_team_member_message( + "information to the team. They'll get back to you shortly!" ), thread_ts=thread_ts, + receiver_ids=None, + send_as_ephemeral=send_as_ephemeral, ) diff --git a/backend/onyx/onyxbot/slack/listener.py b/backend/onyx/onyxbot/slack/listener.py index 343d8d1c6..dcfb012e0 100644 --- a/backend/onyx/onyxbot/slack/listener.py +++ b/backend/onyx/onyxbot/slack/listener.py @@ -57,7 +57,9 @@ from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID +from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID +from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID from onyx.onyxbot.slack.handlers.handle_buttons import handle_doc_feedback_button from onyx.onyxbot.slack.handlers.handle_buttons import handle_followup_button @@ -67,6 +69,9 @@ from onyx.onyxbot.slack.handlers.handle_buttons import ( from onyx.onyxbot.slack.handlers.handle_buttons import ( handle_generate_answer_button, ) +from onyx.onyxbot.slack.handlers.handle_buttons import ( + handle_publish_ephemeral_message_button, +) from onyx.onyxbot.slack.handlers.handle_buttons import handle_slack_feedback from onyx.onyxbot.slack.handlers.handle_message import handle_message from onyx.onyxbot.slack.handlers.handle_message import ( @@ -81,7 +86,7 @@ from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id from onyx.onyxbot.slack.utils import read_slack_thread from onyx.onyxbot.slack.utils import remove_onyx_bot_tag from onyx.onyxbot.slack.utils import rephrase_slack_message -from onyx.onyxbot.slack.utils import respond_in_thread +from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import TenantSocketModeClient from onyx.redis.redis_pool import get_redis_client from onyx.server.manage.models import 
SlackBotTokens @@ -667,7 +672,11 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> feedback_msg_reminder = cast(str, action.get("value")) feedback_id = cast(str, action.get("block_id")) channel_id = cast(str, req.payload["container"]["channel_id"]) - thread_ts = cast(str, req.payload["container"]["thread_ts"]) + thread_ts = cast( + str, + req.payload["container"].get("thread_ts") + or req.payload["container"].get("message_ts"), + ) else: logger.error("Unable to process feedback. Action not found") return @@ -783,7 +792,7 @@ def apologize_for_fail( details: SlackMessageInfo, client: TenantSocketModeClient, ) -> None: - respond_in_thread( + respond_in_thread_or_channel( client=client.web_client, channel=details.channel_to_respond, thread_ts=details.msg_to_respond, @@ -859,6 +868,14 @@ def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> No if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]: # AI Answer feedback return process_feedback(req, client) + elif action["action_id"] in [ + SHOW_EVERYONE_ACTION_ID, + KEEP_TO_YOURSELF_ACTION_ID, + ]: + # Publish ephemeral message or keep hidden in main channel + return handle_publish_ephemeral_message_button( + req, client, action["action_id"] + ) elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID: # Activation of the "source feedback" button return handle_doc_feedback_button(req, client) diff --git a/backend/onyx/onyxbot/slack/models.py b/backend/onyx/onyxbot/slack/models.py index f3cb6add2..81b8bf1f4 100644 --- a/backend/onyx/onyxbot/slack/models.py +++ b/backend/onyx/onyxbot/slack/models.py @@ -1,3 +1,5 @@ +from typing import Literal + from pydantic import BaseModel from onyx.chat.models import ThreadMessage @@ -13,3 +15,37 @@ class SlackMessageInfo(BaseModel): bypass_filters: bool # User has tagged @OnyxBot is_bot_msg: bool # User is using /OnyxBot is_bot_dm: bool # User is direct messaging to OnyxBot + + +# Models used to encode the relevant data for the ephemeral message actions +class ActionValuesEphemeralMessageMessageInfo(BaseModel): + bypass_filters: bool | None + channel_to_respond: str | None + msg_to_respond: str | None + email: str | None + sender_id: str | None + thread_messages: list[ThreadMessage] | None + is_bot_msg: bool | None + is_bot_dm: bool | None + thread_to_respond: str | None + + +class ActionValuesEphemeralMessageChannelConfig(BaseModel): + channel_name: str | None + respond_tag_only: bool | None + respond_to_bots: bool | None + is_ephemeral: bool + respond_member_group_list: list[str] | None + answer_filters: list[ + Literal["well_answered_postfilter", "questionmark_prefilter"] + ] | None + follow_up_tags: list[str] | None + show_continue_in_web_ui: bool + + +class ActionValuesEphemeralMessage(BaseModel): + original_question_ts: str | None + feedback_reminder_id: str | None + chat_message_id: int + message_info: ActionValuesEphemeralMessageMessageInfo + channel_conf: ActionValuesEphemeralMessageChannelConfig diff --git a/backend/onyx/onyxbot/slack/utils.py b/backend/onyx/onyxbot/slack/utils.py index dc06d6bbe..f232bb08e 100644 --- a/backend/onyx/onyxbot/slack/utils.py +++ b/backend/onyx/onyxbot/slack/utils.py @@ -184,7 +184,7 @@ def _build_error_block(error_message: str) -> Block: backoff=2, logger=cast(logging.Logger, logger), ) -def respond_in_thread( +def respond_in_thread_or_channel( client: WebClient, channel: str, thread_ts: str | None, @@ -193,6 +193,7 @@ def respond_in_thread( receiver_ids: list[str] | None = None, 
metadata: Metadata | None = None, unfurl: bool = True, + send_as_ephemeral: bool | None = True, ) -> list[str]: if not text and not blocks: raise ValueError("One of `text` or `blocks` must be provided") @@ -236,6 +237,7 @@ def respond_in_thread( message_ids.append(response["message_ts"]) else: slack_call = make_slack_api_rate_limited(client.chat_postEphemeral) + for receiver in receiver_ids: try: response = slack_call( @@ -299,6 +301,12 @@ def build_feedback_id( return unique_prefix + ID_SEPARATOR + feedback_id +def build_publish_ephemeral_message_id( + original_question_ts: str, +) -> str: + return "publish_ephemeral_message__" + original_question_ts + + def build_continue_in_web_ui_id( message_id: int, ) -> str: @@ -539,7 +547,7 @@ def read_slack_thread( # If auto-detected filters are on, use the second block for the actual answer # The first block is the auto-detected filters - if message.startswith("_Filters"): + if message is not None and message.startswith("_Filters"): if len(blocks) < 2: logger.warning(f"Only filter blocks found: {reply}") continue @@ -611,7 +619,7 @@ class SlackRateLimiter: def notify( self, client: WebClient, channel: str, position: int, thread_ts: str | None ) -> None: - respond_in_thread( + respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=None, diff --git a/backend/onyx/server/documents/credential.py b/backend/onyx/server/documents/credential.py index 473a7fbe5..76f08dd30 100644 --- a/backend/onyx/server/documents/credential.py +++ b/backend/onyx/server/documents/credential.py @@ -13,6 +13,7 @@ from onyx.db.credentials import cleanup_gmail_credentials from onyx.db.credentials import create_credential from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE from onyx.db.credentials import delete_credential +from onyx.db.credentials import delete_credential_for_user from onyx.db.credentials import fetch_credential_by_id_for_user from onyx.db.credentials import fetch_credentials_by_source_for_user from onyx.db.credentials import fetch_credentials_for_user @@ -88,7 +89,7 @@ def delete_credential_by_id_admin( db_session: Session = Depends(get_session), ) -> StatusResponse: """Same as the user endpoint, but can delete any credential (not just the user's own)""" - delete_credential(db_session=db_session, credential_id=credential_id, user=None) + delete_credential(db_session=db_session, credential_id=credential_id) return StatusResponse( success=True, message="Credential deleted successfully", data=credential_id ) @@ -242,7 +243,7 @@ def delete_credential_by_id( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> StatusResponse: - delete_credential( + delete_credential_for_user( credential_id, user, db_session, @@ -259,7 +260,7 @@ def force_delete_credential_by_id( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> StatusResponse: - delete_credential(credential_id, user, db_session, True) + delete_credential_for_user(credential_id, user, db_session, True) return StatusResponse( success=True, message="Credential deleted successfully", data=credential_id diff --git a/backend/onyx/server/features/folder/api.py b/backend/onyx/server/features/folder/api.py index c558540fe..56252ed40 100644 --- a/backend/onyx/server/features/folder/api.py +++ b/backend/onyx/server/features/folder/api.py @@ -49,6 +49,7 @@ def get_folders( name=chat_session.description, persona_id=chat_session.persona_id, time_created=chat_session.time_created.isoformat(), + 
time_updated=chat_session.time_updated.isoformat(), shared_status=chat_session.shared_status, folder_id=folder.id, ) diff --git a/backend/onyx/server/manage/models.py b/backend/onyx/server/manage/models.py index 786b9074a..cf51a7b08 100644 --- a/backend/onyx/server/manage/models.py +++ b/backend/onyx/server/manage/models.py @@ -181,6 +181,7 @@ class SlackChannelConfigCreationRequest(BaseModel): channel_name: str respond_tag_only: bool = False respond_to_bots: bool = False + is_ephemeral: bool = False show_continue_in_web_ui: bool = False enable_auto_filters: bool = False # If no team members, assume respond in the channel to everyone diff --git a/backend/onyx/server/manage/search_settings.py b/backend/onyx/server/manage/search_settings.py index 8935a3124..663f62f4d 100644 --- a/backend/onyx/server/manage/search_settings.py +++ b/backend/onyx/server/manage/search_settings.py @@ -72,11 +72,13 @@ def set_new_search_settings( and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX) ): index_name += ALT_INDEX_SUFFIX - search_values = search_settings_new.dict() + search_values = search_settings_new.model_dump() search_values["index_name"] = index_name new_search_settings_request = SavedSearchSettings(**search_values) else: - new_search_settings_request = SavedSearchSettings(**search_settings_new.dict()) + new_search_settings_request = SavedSearchSettings( + **search_settings_new.model_dump() + ) secondary_search_settings = get_secondary_search_settings(db_session) @@ -103,8 +105,10 @@ def set_new_search_settings( document_index = get_default_document_index(search_settings, new_search_settings) document_index.ensure_indices_exist( - index_embedding_dim=search_settings.model_dim, - secondary_index_embedding_dim=new_search_settings.model_dim, + primary_embedding_dim=search_settings.final_embedding_dim, + primary_embedding_precision=search_settings.embedding_precision, + secondary_index_embedding_dim=new_search_settings.final_embedding_dim, + secondary_index_embedding_precision=new_search_settings.embedding_precision, ) # Pause index attempts for the currently in use index to preserve resources @@ -137,6 +141,17 @@ def cancel_new_embedding( db_session=db_session, ) + # remove the old index from the vector db + primary_search_settings = get_current_search_settings(db_session) + document_index = get_default_document_index(primary_search_settings, None) + document_index.ensure_indices_exist( + primary_embedding_dim=primary_search_settings.final_embedding_dim, + primary_embedding_precision=primary_search_settings.embedding_precision, + # just finished swap, no more secondary index + secondary_index_embedding_dim=None, + secondary_index_embedding_precision=None, + ) + @router.delete("/delete-search-settings") def delete_search_settings_endpoint( diff --git a/backend/onyx/server/manage/slack_bot.py b/backend/onyx/server/manage/slack_bot.py index ebca360aa..52ebf18b7 100644 --- a/backend/onyx/server/manage/slack_bot.py +++ b/backend/onyx/server/manage/slack_bot.py @@ -71,6 +71,15 @@ def _form_channel_config( "also respond to a predetermined set of users." ) + if ( + slack_channel_config_creation_request.is_ephemeral + and slack_channel_config_creation_request.respond_member_group_list + ): + raise ValueError( + "Cannot set OnyxBot to respond to users in a private (ephemeral) message " + "and also respond to a selected list of users." 
+ ) + channel_config: ChannelConfig = { "channel_name": cleaned_channel_name, } @@ -91,6 +100,8 @@ def _form_channel_config( "respond_to_bots" ] = slack_channel_config_creation_request.respond_to_bots + channel_config["is_ephemeral"] = slack_channel_config_creation_request.is_ephemeral + channel_config["disabled"] = slack_channel_config_creation_request.disabled return channel_config @@ -343,7 +354,8 @@ def list_bot_configs( ] -MAX_CHANNELS = 200 +MAX_SLACK_PAGES = 5 +SLACK_API_CHANNELS_PER_PAGE = 100 @router.get( @@ -355,8 +367,8 @@ def get_all_channels_from_slack_api( _: User | None = Depends(current_admin_user), ) -> list[SlackChannel]: """ - Fetches all channels from the Slack API. - If the workspace has 200 or more channels, we raise an error. + Fetches channels the bot is a member of from the Slack API. + Handles pagination with a limit to avoid excessive API calls. """ tokens = fetch_slack_bot_tokens(db_session, bot_id) if not tokens or "bot_token" not in tokens: @@ -365,28 +377,60 @@ def get_all_channels_from_slack_api( ) client = WebClient(token=tokens["bot_token"]) + all_channels = [] + next_cursor = None + current_page = 0 try: - response = client.conversations_list( - types="public_channel,private_channel", - exclude_archived=True, - limit=MAX_CHANNELS, - ) + # Use users_conversations with limited pagination + while current_page < MAX_SLACK_PAGES: + current_page += 1 + + # Make API call with cursor if we have one + if next_cursor: + response = client.users_conversations( + types="public_channel,private_channel", + exclude_archived=True, + cursor=next_cursor, + limit=SLACK_API_CHANNELS_PER_PAGE, + ) + else: + response = client.users_conversations( + types="public_channel,private_channel", + exclude_archived=True, + limit=SLACK_API_CHANNELS_PER_PAGE, + ) + + # Add channels to our list + if "channels" in response and response["channels"]: + all_channels.extend(response["channels"]) + + # Check if we need to paginate + if ( + "response_metadata" in response + and "next_cursor" in response["response_metadata"] + ): + next_cursor = response["response_metadata"]["next_cursor"] + if next_cursor: + if current_page == MAX_SLACK_PAGES: + raise HTTPException( + status_code=400, + detail="Workspace has too many channels to paginate over in this call.", + ) + continue + + # If we get here, no more pages + break channels = [ SlackChannel(id=channel["id"], name=channel["name"]) - for channel in response["channels"] + for channel in all_channels ] - if len(channels) == MAX_CHANNELS: - raise HTTPException( - status_code=400, - detail=f"Workspace has {MAX_CHANNELS} or more channels.", - ) - return channels except SlackApiError as e: + # Handle rate limiting or other API errors raise HTTPException( status_code=500, detail=f"Error fetching channels from Slack API: {str(e)}", diff --git a/backend/onyx/server/openai_assistants_api/threads_api.py b/backend/onyx/server/openai_assistants_api/threads_api.py index 9951a950d..4db15984f 100644 --- a/backend/onyx/server/openai_assistants_api/threads_api.py +++ b/backend/onyx/server/openai_assistants_api/threads_api.py @@ -147,9 +147,11 @@ def list_threads( name=chat.description, persona_id=chat.persona_id, time_created=chat.time_created.isoformat(), + time_updated=chat.time_updated.isoformat(), shared_status=chat.shared_status, folder_id=chat.folder_id, current_alternate_model=chat.current_alternate_model, + current_temperature_override=chat.temperature_override, ) for chat in chat_sessions ] diff --git a/backend/onyx/server/query_and_chat/chat_backend.py 
b/backend/onyx/server/query_and_chat/chat_backend.py index 650d2aa0a..86988e887 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -119,6 +119,7 @@ def get_user_chat_sessions( name=chat.description, persona_id=chat.persona_id, time_created=chat.time_created.isoformat(), + time_updated=chat.time_updated.isoformat(), shared_status=chat.shared_status, folder_id=chat.folder_id, current_alternate_model=chat.current_alternate_model, diff --git a/backend/onyx/server/query_and_chat/models.py b/backend/onyx/server/query_and_chat/models.py index 2c68b38f1..132be33ca 100644 --- a/backend/onyx/server/query_and_chat/models.py +++ b/backend/onyx/server/query_and_chat/models.py @@ -181,6 +181,7 @@ class ChatSessionDetails(BaseModel): name: str | None persona_id: int | None = None time_created: str + time_updated: str shared_status: ChatSessionSharedStatus folder_id: int | None = None current_alternate_model: str | None = None @@ -241,6 +242,7 @@ class ChatMessageDetail(BaseModel): files: list[FileDescriptor] tool_call: ToolCallFinalResult | None refined_answer_improvement: bool | None = None + is_agentic: bool | None = None error: str | None = None def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore diff --git a/backend/onyx/server/query_and_chat/query_backend.py b/backend/onyx/server/query_and_chat/query_backend.py index b296456de..27df69d5a 100644 --- a/backend/onyx/server/query_and_chat/query_backend.py +++ b/backend/onyx/server/query_and_chat/query_backend.py @@ -159,6 +159,7 @@ def get_user_search_sessions( name=sessions_with_documents_dict[search.id], persona_id=search.persona_id, time_created=search.time_created.isoformat(), + time_updated=search.time_updated.isoformat(), shared_status=search.shared_status, folder_id=search.folder_id, current_alternate_model=search.current_alternate_model, diff --git a/backend/onyx/server/utils.py b/backend/onyx/server/utils.py index 8dc7a429b..8d8643a51 100644 --- a/backend/onyx/server/utils.py +++ b/backend/onyx/server/utils.py @@ -46,13 +46,21 @@ def mask_string(sensitive_str: str) -> str: return "****...**" + sensitive_str[-4:] +MASK_CREDENTIALS_WHITELIST = { + DB_CREDENTIALS_AUTHENTICATION_METHOD, + "wiki_base", + "cloud_name", + "cloud_id", +} + + def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: masked_creds = {} for key, val in credential_dict.items(): if isinstance(val, str): # we want to pass the authentication_method field through so the frontend # can disambiguate credentials created by different methods - if key == DB_CREDENTIALS_AUTHENTICATION_METHOD: + if key in MASK_CREDENTIALS_WHITELIST: masked_creds[key] = val else: masked_creds[key] = mask_string(val) @@ -63,8 +71,8 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]: continue raise ValueError( - f"Unable to mask credentials of type other than string, cannot process request." - f"Recieved type: {type(val)}" + f"Unable to mask credentials of type other than string or int, cannot process request." 
+ f"Received type: {type(val)}" ) return masked_creds diff --git a/backend/onyx/setup.py b/backend/onyx/setup.py index fc4b9f983..1dff601ef 100644 --- a/backend/onyx/setup.py +++ b/backend/onyx/setup.py @@ -21,6 +21,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs from onyx.db.connector_credential_pair import resync_cc_pair from onyx.db.credentials import create_initial_public_credential from onyx.db.document import check_docs_exist +from onyx.db.enums import EmbeddingPrecision from onyx.db.index_attempt import cancel_indexing_attempts_past_model from onyx.db.index_attempt import expire_index_attempts from onyx.db.llm import fetch_default_provider @@ -32,7 +33,7 @@ from onyx.db.search_settings import get_current_search_settings from onyx.db.search_settings import get_secondary_search_settings from onyx.db.search_settings import update_current_search_settings from onyx.db.search_settings import update_secondary_search_settings -from onyx.db.swap_index import check_index_swap +from onyx.db.swap_index import check_and_perform_index_swap from onyx.document_index.factory import get_default_document_index from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.vespa.index import VespaIndex @@ -73,7 +74,7 @@ def setup_onyx( The Tenant Service calls the tenants/create endpoint which runs this. """ - check_index_swap(db_session=db_session) + check_and_perform_index_swap(db_session=db_session) active_search_settings = get_active_search_settings(db_session) search_settings = active_search_settings.primary @@ -243,10 +244,18 @@ def setup_vespa( try: logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...") document_index.ensure_indices_exist( - index_embedding_dim=index_setting.model_dim, - secondary_index_embedding_dim=secondary_index_setting.model_dim - if secondary_index_setting - else None, + primary_embedding_dim=index_setting.final_embedding_dim, + primary_embedding_precision=index_setting.embedding_precision, + secondary_index_embedding_dim=( + secondary_index_setting.final_embedding_dim + if secondary_index_setting + else None + ), + secondary_index_embedding_precision=( + secondary_index_setting.embedding_precision + if secondary_index_setting + else None + ), ) logger.notice("Vespa setup complete.") @@ -360,6 +369,11 @@ def setup_vespa_multitenant(supported_indices: list[SupportedEmbeddingModel]) -> ], embedding_dims=[index.dim for index in supported_indices] + [index.dim for index in supported_indices], + # on the cloud, just use float for all indices, the option to change this + # is not exposed to the user + embedding_precisions=[ + EmbeddingPrecision.FLOAT for _ in range(len(supported_indices) * 2) + ], ) logger.notice("Vespa setup complete.") diff --git a/backend/scripts/query_time_check/seed_dummy_docs.py b/backend/scripts/query_time_check/seed_dummy_docs.py index f0d838415..6d33ae902 100644 --- a/backend/scripts/query_time_check/seed_dummy_docs.py +++ b/backend/scripts/query_time_check/seed_dummy_docs.py @@ -136,7 +136,7 @@ def seed_dummy_docs( search_settings = get_current_search_settings(db_session) multipass_config = get_multipass_config(search_settings) index_name = search_settings.index_name - embedding_dim = search_settings.model_dim + embedding_dim = search_settings.final_embedding_dim vespa_index = VespaIndex( index_name=index_name, diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index 9f7e853d2..644f315fa 100644 --- 
a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -30,6 +30,12 @@ class EmbedRequest(BaseModel): manual_passage_prefix: str | None = None api_url: str | None = None api_version: str | None = None + + # allows for the truncation of the vector to a lower dimension + # to reduce memory usage. Currently only supported for OpenAI models. + # will be ignored for other providers. + reduced_dimension: int | None = None + # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} diff --git a/backend/tests/daily/connectors/confluence/test_confluence_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_basic.py index 26d86c557..7cc80fb2f 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_basic.py @@ -5,7 +5,9 @@ from unittest.mock import patch import pytest +from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider from onyx.connectors.models import Document @@ -18,12 +20,15 @@ def confluence_connector() -> ConfluenceConnector: page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""), ) - connector.load_credentials( + credentials_provider = OnyxStaticCredentialsProvider( + None, + DocumentSource.CONFLUENCE, { "confluence_username": os.environ["CONFLUENCE_USER_NAME"], "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], - } + }, ) + connector.set_credentials_provider(credentials_provider) return connector diff --git a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py index 0f66a993d..6bb43437e 100644 --- a/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py +++ b/backend/tests/daily/connectors/confluence/test_confluence_permissions_basic.py @@ -2,7 +2,9 @@ import os import pytest +from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.connector import ConfluenceConnector +from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider @pytest.fixture @@ -11,12 +13,16 @@ def confluence_connector() -> ConfluenceConnector: wiki_base="https://danswerai.atlassian.net", is_cloud=True, ) - connector.load_credentials( + + credentials_provider = OnyxStaticCredentialsProvider( + None, + DocumentSource.CONFLUENCE, { - "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], "confluence_username": os.environ["CONFLUENCE_USER_NAME"], - } + "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"], + }, ) + connector.set_credentials_provider(credentials_provider) return connector diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 6c306bb46..0c89766b9 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -16,7 +16,7 @@ from onyx.db.engine import SYNC_DB_API from onyx.db.search_settings import get_current_search_settings from onyx.db.session import get_session_context_manager from onyx.db.session import get_session_with_tenant -from onyx.db.swap_index import check_index_swap +from onyx.db.swap_index import check_and_perform_index_swap from onyx.db.tenant import get_all_tenant_ids from onyx.document_index.document_index_utils import 
get_multipass_config from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT @@ -194,7 +194,7 @@ def reset_vespa() -> None: with get_session_context_manager() as db_session: # swap to the correct default model - check_index_swap(db_session) + check_and_perform_index_swap(db_session) search_settings = get_current_search_settings(db_session) multipass_config = get_multipass_config(search_settings) @@ -289,7 +289,7 @@ def reset_vespa_multitenant() -> None: for tenant_id in get_all_tenant_ids(): with get_session_with_tenant(tenant_id=tenant_id) as db_session: # swap to the correct default model for each tenant - check_index_swap(db_session) + check_and_perform_index_swap(db_session) search_settings = get_current_search_settings(db_session) multipass_config = get_multipass_config(search_settings) diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 4c8b2be0c..282f5cdc2 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -108,3 +108,13 @@ def admin_user() -> DATestUser: @pytest.fixture def reset_multitenant() -> None: reset_all_multitenant() + + +def pytest_runtest_logstart(nodeid: str, location: tuple[str, int | None, str]) -> None: + print(f"\nTest start: {nodeid}") + + +def pytest_runtest_logfinish( + nodeid: str, location: tuple[str, int | None, str] +) -> None: + print(f"\nTest end: {nodeid}") diff --git a/backend/tests/integration/tests/pruning/test_pruning.py b/backend/tests/integration/tests/pruning/test_pruning.py index 96153db94..997c84cad 100644 --- a/backend/tests/integration/tests/pruning/test_pruning.py +++ b/backend/tests/integration/tests/pruning/test_pruning.py @@ -142,8 +142,12 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None: selected_cc_pair = CCPairManager.get_indexing_status_by_id( cc_pair_1.id, user_performing_action=admin_user ) + assert selected_cc_pair is not None, "cc_pair not found after indexing!" - assert selected_cc_pair.docs_indexed == 15 + + # used to be 15, but now + # localhost:8889/ and localhost:8889/index.html are deduped + assert selected_cc_pair.docs_indexed == 14 logger.info("Removing about.html.") os.remove(os.path.join(website_tgt, "about.html")) @@ -160,24 +164,29 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None: cc_pair_1.id, user_performing_action=admin_user ) assert selected_cc_pair is not None, "cc_pair not found after pruning!" 
- assert selected_cc_pair.docs_indexed == 13 + assert selected_cc_pair.docs_indexed == 12 # check vespa + root_id = f"http://{hostname}:{port}/" index_id = f"http://{hostname}:{port}/index.html" about_id = f"http://{hostname}:{port}/about.html" courses_id = f"http://{hostname}:{port}/courses.html" - doc_ids = [index_id, about_id, courses_id] + doc_ids = [root_id, index_id, about_id, courses_id] retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"] retrieved_docs = { doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict } - # verify index.html exists in Vespa - retrieved_doc = retrieved_docs.get(index_id) + # verify root exists in Vespa + retrieved_doc = retrieved_docs.get(root_id) assert retrieved_doc + # verify index.html does not exist in Vespa since it is a duplicate of root + retrieved_doc = retrieved_docs.get(index_id) + assert not retrieved_doc + # verify about and courses do not exist retrieved_doc = retrieved_docs.get(about_id) assert not retrieved_doc diff --git a/backend/tests/unit/model_server/test_embedding.py b/backend/tests/unit/model_server/test_embedding.py index 6781ab27a..17068f3a6 100644 --- a/backend/tests/unit/model_server/test_embedding.py +++ b/backend/tests/unit/model_server/test_embedding.py @@ -64,7 +64,7 @@ async def test_openai_embedding( embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) result = await embedding._embed_openai( - ["test1", "test2"], "text-embedding-ada-002" + ["test1", "test2"], "text-embedding-ada-002", None ) assert result == sample_embeddings @@ -89,6 +89,7 @@ async def test_embed_text_cloud_provider() -> None: prefix=None, api_url=None, api_version=None, + reduced_dimension=None, ) assert result == [[0.1, 0.2], [0.3, 0.4]] @@ -114,6 +115,7 @@ async def test_embed_text_local_model() -> None: prefix=None, api_url=None, api_version=None, + reduced_dimension=None, ) assert result == [[0.1, 0.2], [0.3, 0.4]] @@ -157,6 +159,7 @@ async def test_rate_limit_handling() -> None: prefix=None, api_url=None, api_version=None, + reduced_dimension=None, ) @@ -179,6 +182,7 @@ async def test_concurrent_embeddings() -> None: manual_passage_prefix=None, api_url=None, api_version=None, + reduced_dimension=None, ) with patch("model_server.encoders.get_embedding_model") as mock_get_model: diff --git a/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py b/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py index ed77d7764..c7e88e5dd 100644 --- a/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py +++ b/backend/tests/unit/onyx/connectors/confluence/test_rate_limit_handler.py @@ -3,9 +3,7 @@ from unittest.mock import Mock import pytest from requests import HTTPError -from onyx.connectors.confluence.onyx_confluence import ( - handle_confluence_rate_limit, -) +from onyx.connectors.confluence.utils import handle_confluence_rate_limit @pytest.fixture @@ -50,6 +48,8 @@ def mock_confluence_call() -> Mock: # mock_sleep.assert_called_with(int(retry_after)) +# NOTE(rkuo): This tests an older version of rate limiting that is being deprecated +# and probably should go away soon. 
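# For context, a minimal usage sketch of the wrapper under test (assuming it
# takes a callable and returns a callable that retries when Confluence
# responds with HTTP 429):
#
#     wrapped = handle_confluence_rate_limit(confluence_client.get_page_by_id)
#     page = wrapped("12345")  # retried automatically if rate limited
#
# The test below exercises the non-rate-limited error path: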
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None: mock_confluence_call.side_effect = HTTPError( response=Mock(status_code=500, text="Internal Server Error") diff --git a/web/package-lock.json b/web/package-lock.json index e0fd1e981..069bcebd3 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -52,7 +52,7 @@ "lodash": "^4.17.21", "lucide-react": "^0.454.0", "mdast-util-find-and-replace": "^3.0.1", - "next": "^15.0.2", + "next": "^15.2.0", "next-themes": "^0.4.4", "npm": "^10.8.0", "postcss": "^8.4.31", @@ -2631,9 +2631,10 @@ } }, "node_modules/@next/env": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/env/-/env-15.0.2.tgz", - "integrity": "sha512-c0Zr0ModK5OX7D4ZV8Jt/wqoXtitLNPwUfG9zElCZztdaZyNVnN40rDXVZ/+FGuR4CcNV5AEfM6N8f+Ener7Dg==" + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/env/-/env-15.2.0.tgz", + "integrity": "sha512-eMgJu1RBXxxqqnuRJQh5RozhskoNUDHBFybvi+Z+yK9qzKeG7dadhv/Vp1YooSZmCnegf7JxWuapV77necLZNA==", + "license": "MIT" }, "node_modules/@next/eslint-plugin-next": { "version": "14.2.3", @@ -2645,12 +2646,13 @@ } }, "node_modules/@next/swc-darwin-arm64": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.0.2.tgz", - "integrity": "sha512-GK+8w88z+AFlmt+ondytZo2xpwlfAR8U6CRwXancHImh6EdGfHMIrTSCcx5sOSBei00GyLVL0ioo1JLKTfprgg==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.2.0.tgz", + "integrity": "sha512-rlp22GZwNJjFCyL7h5wz9vtpBVuCt3ZYjFWpEPBGzG712/uL1bbSkS675rVAUCRZ4hjoTJ26Q7IKhr5DfJrHDA==", "cpu": [ "arm64" ], + "license": "MIT", "optional": true, "os": [ "darwin" @@ -2660,12 +2662,13 @@ } }, "node_modules/@next/swc-darwin-x64": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.0.2.tgz", - "integrity": "sha512-KUpBVxIbjzFiUZhiLIpJiBoelqzQtVZbdNNsehhUn36e2YzKHphnK8eTUW1s/4aPy5kH/UTid8IuVbaOpedhpw==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.2.0.tgz", + "integrity": "sha512-DiU85EqSHogCz80+sgsx90/ecygfCSGl5P3b4XDRVZpgujBm5lp4ts7YaHru7eVTyZMjHInzKr+w0/7+qDrvMA==", "cpu": [ "x64" ], + "license": "MIT", "optional": true, "os": [ "darwin" @@ -2675,12 +2678,13 @@ } }, "node_modules/@next/swc-linux-arm64-gnu": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.0.2.tgz", - "integrity": "sha512-9J7TPEcHNAZvwxXRzOtiUvwtTD+fmuY0l7RErf8Yyc7kMpE47MIQakl+3jecmkhOoIyi/Rp+ddq7j4wG6JDskQ==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.2.0.tgz", + "integrity": "sha512-VnpoMaGukiNWVxeqKHwi8MN47yKGyki5q+7ql/7p/3ifuU2341i/gDwGK1rivk0pVYbdv5D8z63uu9yMw0QhpQ==", "cpu": [ "arm64" ], + "license": "MIT", "optional": true, "os": [ "linux" @@ -2690,12 +2694,13 @@ } }, "node_modules/@next/swc-linux-arm64-musl": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.0.2.tgz", - "integrity": "sha512-BjH4ZSzJIoTTZRh6rG+a/Ry4SW0HlizcPorqNBixBWc3wtQtj4Sn9FnRZe22QqrPnzoaW0ctvSz4FaH4eGKMww==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.2.0.tgz", + "integrity": "sha512-ka97/ssYE5nPH4Qs+8bd8RlYeNeUVBhcnsNUmFM6VWEob4jfN9FTr0NBhXVi1XEJpj3cMfgSRW+LdE3SUZbPrw==", "cpu": [ "arm64" ], + "license": "MIT", "optional": true, 
"os": [ "linux" @@ -2705,12 +2710,13 @@ } }, "node_modules/@next/swc-linux-x64-gnu": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.0.2.tgz", - "integrity": "sha512-i3U2TcHgo26sIhcwX/Rshz6avM6nizrZPvrDVDY1bXcLH1ndjbO8zuC7RoHp0NSK7wjJMPYzm7NYL1ksSKFreA==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.2.0.tgz", + "integrity": "sha512-zY1JduE4B3q0k2ZCE+DAF/1efjTXUsKP+VXRtrt/rJCTgDlUyyryx7aOgYXNc1d8gobys/Lof9P9ze8IyRDn7Q==", "cpu": [ "x64" ], + "license": "MIT", "optional": true, "os": [ "linux" @@ -2720,12 +2726,13 @@ } }, "node_modules/@next/swc-linux-x64-musl": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.0.2.tgz", - "integrity": "sha512-AMfZfSVOIR8fa+TXlAooByEF4OB00wqnms1sJ1v+iu8ivwvtPvnkwdzzFMpsK5jA2S9oNeeQ04egIWVb4QWmtQ==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.2.0.tgz", + "integrity": "sha512-QqvLZpurBD46RhaVaVBepkVQzh8xtlUN00RlG4Iq1sBheNugamUNPuZEH1r9X1YGQo1KqAe1iiShF0acva3jHQ==", "cpu": [ "x64" ], + "license": "MIT", "optional": true, "os": [ "linux" @@ -2735,12 +2742,13 @@ } }, "node_modules/@next/swc-win32-arm64-msvc": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.0.2.tgz", - "integrity": "sha512-JkXysDT0/hEY47O+Hvs8PbZAeiCQVxKfGtr4GUpNAhlG2E0Mkjibuo8ryGD29Qb5a3IOnKYNoZlh/MyKd2Nbww==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.2.0.tgz", + "integrity": "sha512-ODZ0r9WMyylTHAN6pLtvUtQlGXBL9voljv6ujSlcsjOxhtXPI1Ag6AhZK0SE8hEpR1374WZZ5w33ChpJd5fsjw==", "cpu": [ "arm64" ], + "license": "MIT", "optional": true, "os": [ "win32" @@ -2750,12 +2758,13 @@ } }, "node_modules/@next/swc-win32-x64-msvc": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.0.2.tgz", - "integrity": "sha512-foaUL0NqJY/dX0Pi/UcZm5zsmSk5MtP/gxx3xOPyREkMFN+CTjctPfu3QaqrQHinaKdPnMWPJDKt4VjDfTBe/Q==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.2.0.tgz", + "integrity": "sha512-8+4Z3Z7xa13NdUuUAcpVNA6o76lNPniBd9Xbo02bwXQXnZgFvEopwY2at5+z7yHl47X9qbZpvwatZ2BRo3EdZw==", "cpu": [ "x64" ], + "license": "MIT", "optional": true, "os": [ "win32" @@ -7389,11 +7398,12 @@ "integrity": "sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ==" }, "node_modules/@swc/helpers": { - "version": "0.5.13", - "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.13.tgz", - "integrity": "sha512-UoKGxQ3r5kYI9dALKJapMmuK+1zWM/H17Z1+iwnNmzcJRnfFuevZs375TA5rW31pu4BS4NoSy1fRsexDXfWn5w==", + "version": "0.5.15", + "resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.15.tgz", + "integrity": "sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==", + "license": "Apache-2.0", "dependencies": { - "tslib": "^2.4.0" + "tslib": "^2.8.0" } }, "node_modules/@tailwindcss/typography": { @@ -14966,13 +14976,14 @@ "integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==" }, "node_modules/next": { - "version": "15.0.2", - "resolved": "https://registry.npmjs.org/next/-/next-15.0.2.tgz", - "integrity": 
"sha512-rxIWHcAu4gGSDmwsELXacqAPUk+j8dV/A9cDF5fsiCMpkBDYkO2AEaL1dfD+nNmDiU6QMCFN8Q30VEKapT9UHQ==", + "version": "15.2.0", + "resolved": "https://registry.npmjs.org/next/-/next-15.2.0.tgz", + "integrity": "sha512-VaiM7sZYX8KIAHBrRGSFytKknkrexNfGb8GlG6e93JqueCspuGte8i4ybn8z4ww1x3f2uzY4YpTaBEW4/hvsoQ==", + "license": "MIT", "dependencies": { - "@next/env": "15.0.2", + "@next/env": "15.2.0", "@swc/counter": "0.1.3", - "@swc/helpers": "0.5.13", + "@swc/helpers": "0.5.15", "busboy": "1.6.0", "caniuse-lite": "^1.0.30001579", "postcss": "8.4.31", @@ -14982,25 +14993,25 @@ "next": "dist/bin/next" }, "engines": { - "node": ">=18.18.0" + "node": "^18.18.0 || ^19.8.0 || >= 20.0.0" }, "optionalDependencies": { - "@next/swc-darwin-arm64": "15.0.2", - "@next/swc-darwin-x64": "15.0.2", - "@next/swc-linux-arm64-gnu": "15.0.2", - "@next/swc-linux-arm64-musl": "15.0.2", - "@next/swc-linux-x64-gnu": "15.0.2", - "@next/swc-linux-x64-musl": "15.0.2", - "@next/swc-win32-arm64-msvc": "15.0.2", - "@next/swc-win32-x64-msvc": "15.0.2", + "@next/swc-darwin-arm64": "15.2.0", + "@next/swc-darwin-x64": "15.2.0", + "@next/swc-linux-arm64-gnu": "15.2.0", + "@next/swc-linux-arm64-musl": "15.2.0", + "@next/swc-linux-x64-gnu": "15.2.0", + "@next/swc-linux-x64-musl": "15.2.0", + "@next/swc-win32-arm64-msvc": "15.2.0", + "@next/swc-win32-x64-msvc": "15.2.0", "sharp": "^0.33.5" }, "peerDependencies": { "@opentelemetry/api": "^1.1.0", "@playwright/test": "^1.41.2", "babel-plugin-react-compiler": "*", - "react": "^18.2.0 || 19.0.0-rc-02c0e824-20241028", - "react-dom": "^18.2.0 || 19.0.0-rc-02c0e824-20241028", + "react": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0", + "react-dom": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0", "sass": "^1.3.0" }, "peerDependenciesMeta": { @@ -20590,9 +20601,10 @@ } }, "node_modules/tslib": { - "version": "2.6.2", - "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz", - "integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q==" + "version": "2.8.1", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz", + "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", + "license": "0BSD" }, "node_modules/type-check": { "version": "0.4.0", diff --git a/web/package.json b/web/package.json index 22d1df72a..080764eef 100644 --- a/web/package.json +++ b/web/package.json @@ -55,7 +55,7 @@ "lodash": "^4.17.21", "lucide-react": "^0.454.0", "mdast-util-find-and-replace": "^3.0.1", - "next": "^15.0.2", + "next": "^15.2.0", "next-themes": "^0.4.4", "npm": "^10.8.0", "postcss": "^8.4.31", diff --git a/web/src/app/admin/assistants/LabelManagement.tsx b/web/src/app/admin/assistants/LabelManagement.tsx index bdb3690cd..7dc9eaff5 100644 --- a/web/src/app/admin/assistants/LabelManagement.tsx +++ b/web/src/app/admin/assistants/LabelManagement.tsx @@ -100,11 +100,6 @@ export default function LabelManagement() { width="w-full max-w-xs" name={`editLabelName_${label.id}`} label="Label Name" - value={ - values.editLabelId === label.id - ? 
values.editLabelName - : label.name - } onChange={(e) => { setFieldValue("editLabelId", label.id); setFieldValue("editLabelName", e.target.value); diff --git a/web/src/app/admin/assistants/PersonaTable.tsx b/web/src/app/admin/assistants/PersonaTable.tsx index 1177e91ac..44f77f385 100644 --- a/web/src/app/admin/assistants/PersonaTable.tsx +++ b/web/src/app/admin/assistants/PersonaTable.tsx @@ -163,7 +163,7 @@ export function PersonasTable() { {popup} {deleteModalOpen && personaToDelete && ( handleInputChange(index, e.target.value)} className="flex-grow" removeLabel diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx index 503771d98..d1826e955 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx @@ -83,6 +83,8 @@ export const SlackChannelConfigCreationForm = ({ respond_tag_only: existingSlackChannelConfig?.channel_config?.respond_tag_only || false, + is_ephemeral: + existingSlackChannelConfig?.channel_config?.is_ephemeral || false, respond_to_bots: existingSlackChannelConfig?.channel_config?.respond_to_bots || false, @@ -135,6 +137,7 @@ export const SlackChannelConfigCreationForm = ({ questionmark_prefilter_enabled: Yup.boolean().required(), respond_tag_only: Yup.boolean().required(), respond_to_bots: Yup.boolean().required(), + is_ephemeral: Yup.boolean().required(), show_continue_in_web_ui: Yup.boolean().required(), enable_auto_filters: Yup.boolean().required(), respond_member_group_list: Yup.array().of(Yup.string()).required(), diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx index 13a80f6bb..1e82f04b9 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx @@ -199,17 +199,17 @@ export function SlackChannelConfigFormFields({ Default Configuration -

+

This default configuration will apply across all Slack channels the bot is added to in the Slack workspace, as well as direct messages (DMs), unless disabled.

-
+
-

+

Warning: Disabling the default configuration means the bot won't respond in Slack channels or DMs unless explicitly configured for them. @@ -238,20 +238,28 @@ export function SlackChannelConfigFormFields({ />

) : ( - - {({ field, form }: { field: any; form: any }) => ( - { - form.setFieldValue("channel_name", selected.name); - }} - initialSearchTerm={field.value} - onSearchTermChange={(term) => { - form.setFieldValue("channel_name", term); - }} - /> - )} - + <> + + {({ field, form }: { field: any; form: any }) => ( + { + form.setFieldValue("channel_name", selected.name); + }} + initialSearchTerm={field.value} + onSearchTermChange={(term) => { + form.setFieldValue("channel_name", term); + }} + /> + )} + +

+ Note: This list shows public and private channels where the + bot is a member (up to 500 channels). If you don't see a + channel, make sure the bot is added to that channel in Slack + first, or type the channel name manually. +

+ )} )} @@ -589,6 +597,13 @@ export function SlackChannelConfigFormFields({ label="Respond to Bot messages" tooltip="If not set, OnyxBot will always ignore messages from Bots" /> +

- Please note that at least one of the documents accessible by - your OnyxBot is marked as private and may contain sensitive - information. These documents will be accessible to all users - of this OnyxBot. Ensure this aligns with your intended - document sharing policy. + Please note that if the private (ephemeral) response is *not + selected*, only public documents within the selected document + sets will be accessible for user queries. If the private + (ephemeral) response *is selected*, user queries can also + leverage documents that the user has already been granted + access to. Note that users will be able to share the response + with others in the channel, so please ensure that this is + aligned with your company sharing policies.

diff --git a/web/src/app/admin/bots/[bot-id]/lib.ts b/web/src/app/admin/bots/[bot-id]/lib.ts index 5c72c9f20..d9058d188 100644 --- a/web/src/app/admin/bots/[bot-id]/lib.ts +++ b/web/src/app/admin/bots/[bot-id]/lib.ts @@ -14,6 +14,7 @@ interface SlackChannelConfigCreationRequest { answer_validity_check_enabled: boolean; questionmark_prefilter_enabled: boolean; respond_tag_only: boolean; + is_ephemeral: boolean; respond_to_bots: boolean; show_continue_in_web_ui: boolean; respond_member_group_list: string[]; @@ -45,6 +46,7 @@ const buildRequestBodyFromCreationRequest = ( channel_name: creationRequest.channel_name, respond_tag_only: creationRequest.respond_tag_only, respond_to_bots: creationRequest.respond_to_bots, + is_ephemeral: creationRequest.is_ephemeral, show_continue_in_web_ui: creationRequest.show_continue_in_web_ui, enable_auto_filters: creationRequest.enable_auto_filters, respond_member_group_list: creationRequest.respond_member_group_list, diff --git a/web/src/app/admin/configuration/document-processing/page.tsx b/web/src/app/admin/configuration/document-processing/page.tsx index 8cb011d83..98240143c 100644 --- a/web/src/app/admin/configuration/document-processing/page.tsx +++ b/web/src/app/admin/configuration/document-processing/page.tsx @@ -71,7 +71,7 @@ function Main() {

Learn more about Unstructured{" "}

@@ -141,30 +139,46 @@ export default function UpgradingPage({ {connectors && connectors.length > 0 ? ( - <> - {failedIndexingStatus && failedIndexingStatus.length > 0 && ( - - )} + futureEmbeddingModel.background_reindex_enabled ? ( + <> + {failedIndexingStatus && failedIndexingStatus.length > 0 && ( + + )} - - The table below shows the re-indexing progress of all existing - connectors. Once all connectors have been re-indexed - successfully, the new model will be used for all search - queries. Until then, we will use the old model so that no - downtime is necessary during this transition. - + + The table below shows the re-indexing progress of all + existing connectors. Once all connectors have been + re-indexed successfully, the new model will be used for all + search queries. Until then, we will use the old model so + that no downtime is necessary during this transition. + - {sortedReindexingProgress ? ( - - ) : ( - - )} - + {sortedReindexingProgress ? ( + + ) : ( + + )} + + ) : ( +
+

+ Switching Embedding Models +

+

+ You're currently switching embedding models, and + you've selected the instant switch option. The + transition will complete shortly. +

+

+ The new model will be active soon. +

+
+ ) ) : (
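The branch added above boils down to one decision on the future model's `background_reindex_enabled` flag; a condensed sketch (the real component renders the progress table and callouts instead of returning strings):

```ts
// Condensed from the UpgradingPage logic above.
interface FutureSearchSettings {
  model_name: string;
  background_reindex_enabled?: boolean;
}

function describeUpgrade(settings: FutureSearchSettings): string {
  if (settings.background_reindex_enabled) {
    // Connectors are re-indexed in the background; the old model keeps
    // serving all search queries until re-indexing finishes, so there is
    // no downtime during the transition.
    return `Re-indexing for ${settings.model_name}; old model still in use.`;
  }
  // Instant switch: the new model takes over right away, initially over a
  // partial (possibly empty) index.
  return `Switching to ${settings.model_name} immediately.`;
}
```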

diff --git a/web/src/app/admin/connector/[ccPairId]/page.tsx b/web/src/app/admin/connector/[ccPairId]/page.tsx index b21949f12..0f7c33574 100644 --- a/web/src/app/admin/connector/[ccPairId]/page.tsx +++ b/web/src/app/admin/connector/[ccPairId]/page.tsx @@ -455,15 +455,15 @@ function Main({ ccPairId }: { ccPairId: number }) { Indexing Attempts

{indexAttemptErrors && indexAttemptErrors.total_items > 0 && ( - - - + + + Some documents failed to index - + {isResolvingErrors ? ( - + Resolving failures @@ -471,7 +471,7 @@ function Main({ ccPairId }: { ccPairId: number }) { <> We ran into some issues while processing some documents.{" "} setShowIndexAttemptErrors(true)} > View details. diff --git a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx index 44eec87df..4524a32df 100644 --- a/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx +++ b/web/src/app/admin/connectors/[connector]/AddConnectorPage.tsx @@ -193,13 +193,15 @@ export default function AddConnector({ // Check if there are no credentials const noCredentials = credentialTemplate == null; - if (noCredentials && 1 != formStep) { - setFormStep(Math.max(1, formStep)); - } + useEffect(() => { + if (noCredentials && 1 != formStep) { + setFormStep(Math.max(1, formStep)); + } - if (!noCredentials && !credentialActivated && formStep != 0) { - setFormStep(Math.min(formStep, 0)); - } + if (!noCredentials && !credentialActivated && formStep != 0) { + setFormStep(Math.min(formStep, 0)); + } + }, [noCredentials, formStep, setFormStep]); const convertStringToDateTime = (indexingStart: string | null) => { return indexingStart ? new Date(indexingStart) : null; diff --git a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx index 8032c0d7b..eb32c7a87 100644 --- a/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx +++ b/web/src/app/admin/connectors/[connector]/oauth/callback/page.tsx @@ -33,7 +33,7 @@ export default function OAuthCallbackPage() { const connector = pathname?.split("/")[3]; useEffect(() => { - const handleOAuthCallback = async () => { + const onFirstLoad = async () => { // Examples // connector (url segment)= "google-drive" // sourceType (for looking up metadata) = "google_drive" @@ -85,10 +85,19 @@ export default function OAuthCallbackPage() { } setStatusMessage("Success!"); - setStatusDetails( - `Your authorization with ${sourceMetadata.displayName} completed successfully.` - ); - setRedirectUrl(response.redirect_on_success); // Extract the redirect URL + + // set the continuation link + if (response.finalize_url) { + setRedirectUrl(response.finalize_url); + setStatusDetails( + `Your authorization with ${sourceMetadata.displayName} completed successfully. Additional steps are required to complete credential setup.` + ); + } else { + setRedirectUrl(response.redirect_on_success); + setStatusDetails( + `Your authorization with ${sourceMetadata.displayName} completed successfully.` + ); + } setIsError(false); } catch (error) { console.error("OAuth error:", error); @@ -100,15 +109,15 @@ export default function OAuthCallbackPage() { } }; - handleOAuthCallback(); + onFirstLoad(); }, [code, state, connector]); return ( -
+
} /> -
- +
+

{statusMessage}

{statusDetails}

{redirectUrl && !isError && ( diff --git a/web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx b/web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx new file mode 100644 index 000000000..f4e24473c --- /dev/null +++ b/web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx @@ -0,0 +1,293 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { usePathname, useRouter, useSearchParams } from "next/navigation"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { Button } from "@/components/ui/button"; +import Title from "@/components/ui/title"; +import { KeyIcon } from "@/components/icons/icons"; +import { getSourceMetadata, isValidSource } from "@/lib/sources"; +import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types"; +import CardSection from "@/components/admin/CardSection"; +import { + handleOAuthAuthorizationResponse, + handleOAuthConfluenceFinalize, + handleOAuthPrepareFinalization, +} from "@/lib/oauth_utils"; +import { SelectorFormField } from "@/components/admin/connectors/Field"; +import { ErrorMessage, Field, Form, Formik, useFormikContext } from "formik"; +import * as Yup from "yup"; + +// Helper component to keep the effect logic clean: +function UpdateCloudURLOnCloudIdChange({ + accessibleResources, +}: { + accessibleResources: ConfluenceAccessibleResource[]; +}) { + const { values, setValues, setFieldValue } = useFormikContext<{ + cloud_id: string; + cloud_name: string; + cloud_url: string; + }>(); + + useEffect(() => { + // Whenever cloud_id changes, find the matching resource and update cloud_url + if (values.cloud_id) { + const selectedResource = accessibleResources.find( + (resource) => resource.id === values.cloud_id + ); + if (selectedResource) { + // Update multiple fields together ... somehow setting them in sequence + // doesn't work with the validator + // it may also be possible to await each setFieldValue call. + // https://github.com/jaredpalmer/formik/issues/2266 + setValues((prevValues) => ({ + ...prevValues, + cloud_name: selectedResource.name, + cloud_url: selectedResource.url, + })); + } + } + }, [values.cloud_id, accessibleResources, setFieldValue]); + + // This component doesn't render anything visible: + return null; +} + +export default function OAuthFinalizePage() { + const router = useRouter(); + const searchParams = useSearchParams(); + + const [statusMessage, setStatusMessage] = useState("Processing..."); + const [statusDetails, setStatusDetails] = useState( + "Please wait while we complete the setup." + ); + const [redirectUrl, setRedirectUrl] = useState(null); + const [isError, setIsError] = useState(false); + const [isSubmitted, setIsSubmitted] = useState(false); // New state + const [pageTitle, setPageTitle] = useState( + "Finalize Authorization with Third-Party service" + ); + + const [accessibleResources, setAccessibleResources] = useState< + ConfluenceAccessibleResource[] + >([]); + + // Extract query parameters + const credentialParam = searchParams.get("credential"); + const credential = credentialParam ? 
parseInt(credentialParam, 10) : NaN; + const pathname = usePathname(); + const connector = pathname?.split("/")[3]; + + useEffect(() => { + const onFirstLoad = async () => { + // Examples + // connector (url segment)= "google-drive" + // sourceType (for looking up metadata) = "google_drive" + + if (isNaN(credential)) { + setStatusMessage("Improperly formed OAuth finalization request."); + setStatusDetails("Invalid or missing credential id."); + setIsError(true); + return; + } + + const sourceType = connector.replaceAll("-", "_"); + if (!isValidSource(sourceType)) { + setStatusMessage( + `The specified connector source type ${sourceType} does not exist.` + ); + setStatusDetails(`${sourceType} is not a valid source type.`); + setIsError(true); + return; + } + + const sourceMetadata = getSourceMetadata(sourceType as ValidSources); + setPageTitle(`Finalize Authorization with ${sourceMetadata.displayName}`); + + setStatusMessage("Processing..."); + setStatusDetails( + "Please wait while we retrieve a list of your accessible sites." + ); + setIsError(false); // Ensure no error state during loading + + try { + const response = await handleOAuthPrepareFinalization( + connector, + credential + ); + + if (!response) { + throw new Error("Empty response from OAuth server."); + } + + setAccessibleResources(response.accessible_resources); + + setStatusMessage("Select a Confluence site"); + setStatusDetails(""); + + setIsError(false); + } catch (error) { + console.error("OAuth finalization error:", error); + setStatusMessage("Oops, something went wrong!"); + setStatusDetails( + "An error occurred during the OAuth finalization process. Please try again." + ); + setIsError(true); + } + }; + + onFirstLoad(); + }, [credential, connector]); + + useEffect(() => {}, [redirectUrl]); + + return ( +
+ } /> + +
+ +

{statusMessage}

+

{statusDetails}

+ + { + formikHelpers.setSubmitting(true); + try { + if (!values.cloud_id) { + throw new Error("Cloud ID is required."); + } + + if (!values.cloud_name) { + throw new Error("Cloud name is required."); + } + + if (!values.cloud_url) { + throw new Error("Cloud URL is required."); + } + + const response = await handleOAuthConfluenceFinalize( + values.credential_id, + values.cloud_id, + values.cloud_name, + values.cloud_url + ); + formikHelpers.setSubmitting(false); + + if (response) { + setRedirectUrl(response.redirect_url); + setStatusMessage("Confluence authorization finalized."); + } + + setIsSubmitted(true); // Mark as submitted + } catch (error) { + console.error(error); + setStatusMessage("Error during submission."); + setStatusDetails( + "An error occurred during the submission process. Please try again." + ); + setIsError(true); + formikHelpers.setSubmitting(false); + } + }} + > + {({ isSubmitting, isValid, setFieldValue }) => (
+ {/* Debug info +
+
+                    isValid: {String(isValid)}
+                    errors: {JSON.stringify(errors, null, 2)}
+                    values: {JSON.stringify(values, null, 2)}
+                  
+
*/} + + {/* Our helper component that reacts to changes in cloud_id */} + + + + + + + + + {!redirectUrl && accessibleResources.length > 0 && ( + ({ + name: `${resource.name} - ${resource.url}`, + value: resource.id, + }))} + onSelect={(selectedValue) => { + const selectedResource = accessibleResources.find( + (resource) => resource.id === selectedValue + ); + if (selectedResource) { + setFieldValue("cloud_id", selectedResource.id); + } + }} + /> + )} +
+ {!redirectUrl && ( + + )} + + )} +
+ + {redirectUrl && !isError && ( +
+ )} + +
+
+ ); +} diff --git a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx index 89e01be60..e46e4191e 100644 --- a/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/ConnectorInput/NumberInput.tsx @@ -1,5 +1,5 @@ -import { SubLabel } from "@/components/admin/connectors/Field"; -import { Field } from "formik"; +import { Label, SubLabel } from "@/components/admin/connectors/Field"; +import { ErrorMessage, useField } from "formik"; export default function NumberInput({ label, @@ -14,18 +14,36 @@ export default function NumberInput({ description?: string; showNeverIfZero?: boolean; }) { + const [field, meta, helpers] = useField(name); + + const handleChange = (e: React.ChangeEvent) => { + // If the input is empty, set the value to undefined or null + // This prevents the "NaN from empty string" error + if (e.target.value === "") { + helpers.setValue(undefined); + } else { + helpers.setValue(Number(e.target.value)); + } + }; + return (
- + {description && {description}} - +
); } diff --git a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx index 53a91a38f..7268ae0ca 100644 --- a/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/DynamicConnectorCreationForm.tsx @@ -1,4 +1,10 @@ -import React, { Dispatch, FC, SetStateAction, useState } from "react"; +import React, { + Dispatch, + FC, + SetStateAction, + useEffect, + useState, +} from "react"; import CredentialSubText from "@/components/credentials/CredentialFields"; import { ConnectionConfiguration } from "@/lib/connectors/connectors"; import { TextFormField } from "@/components/admin/connectors/Field"; @@ -8,6 +14,7 @@ import { AccessTypeGroupSelector } from "@/components/admin/connectors/AccessTyp import { ConfigurableSources } from "@/lib/types"; import { Credential } from "@/lib/connectors/credentials"; import { RenderField } from "./FieldRendering"; +import { useFormikContext } from "formik"; export interface DynamicConnectionFormProps { config: ConnectionConfiguration; @@ -22,7 +29,25 @@ const DynamicConnectionForm: FC = ({ connector, currentCredential, }) => { + const { values, setFieldValue } = useFormikContext<any>(); // Get Formik's context functions + const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); + const [connectorNameInitialized, setConnectorNameInitialized] = + useState(false); + + let initialConnectorName = ""; + if (config.initialConnectorName) { + initialConnectorName = + currentCredential?.credential_json?.[config.initialConnectorName] ?? ""; + } + + useEffect(() => { + const field_value = values["name"]; + if (initialConnectorName && !connectorNameInitialized && !field_value) { + setFieldValue("name", initialConnectorName); + setConnectorNameInitialized(true); + } + }, [initialConnectorName, setFieldValue, values]); return ( <> diff --git a/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx b/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx index 41cb446aa..a7577f662 100644 --- a/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/FieldRendering.tsx @@ -1,4 +1,4 @@ -import React, { Dispatch, FC, SetStateAction } from "react"; +import React, { Dispatch, FC, SetStateAction, useEffect } from "react"; import { AdminBooleanFormField } from "@/components/credentials/CredentialFields"; import { FileUpload } from "@/components/admin/connectors/FileUpload"; import { TabOption } from "@/lib/connectors/connectors"; @@ -16,6 +16,7 @@ import { TabsList, TabsTrigger, } from "@/components/ui/fully_wrapped_tabs"; +import { useFormikContext } from "formik"; interface TabsFieldProps { tabField: TabOption; @@ -123,6 +124,8 @@ export const RenderField: FC = ({ connector, currentCredential, }) => { + const { values, setFieldValue } = useFormikContext<any>(); // Get Formik's context functions + const label = typeof field.label === "function" ? field.label(currentCredential) @@ -131,6 +134,22 @@ export const RenderField: FC = ({ typeof field.description === "function" ? field.description(currentCredential) : field.description; + const disabled = + typeof field.disabled === "function" + ? field.disabled(currentCredential) + : (field.disabled ?? false); + const initialValue = + typeof field.initial === "function" + ? field.initial(currentCredential) + : (field.initial ??
""); + + // if initialValue exists, prepopulate the field with it + useEffect(() => { + const field_value = values[field.name]; + if (initialValue && field_value === undefined) { + setFieldValue(field.name, initialValue); + } + }, [field.name, initialValue, setFieldValue, values]); if (field.type === "tab") { return ( @@ -176,6 +195,8 @@ export const RenderField: FC = ({ subtext={description} name={field.name} label={label} + disabled={disabled} + onChange={(e) => setFieldValue(field.name, e.target.value)} /> ) : field.type === "text" ? ( = ({ name={field.name} isTextArea={field.isTextArea || false} defaultHeight={"h-15"} + disabled={disabled} + onChange={(e) => setFieldValue(field.name, e.target.value)} /> ) : field.type === "string_tab" ? (
{description}
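The `RenderField` hunk above seeds a field from a credential-derived initial value without clobbering anything the user has already typed. The same pattern in isolation, as a small hook (a sketch; the hook name is ours, the Formik calls are the real API):

```ts
import { useEffect } from "react";
import { useFormikContext } from "formik";

// Seed a Formik field once, only while it is still undefined, mirroring
// the prepopulation effect in RenderField above.
function usePrepopulatedField(name: string, initialValue: string) {
  const { values, setFieldValue } = useFormikContext<Record<string, any>>();

  useEffect(() => {
    if (initialValue && values[name] === undefined) {
      setFieldValue(name, initialValue);
    }
  }, [name, initialValue, setFieldValue, values]);
}
```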
diff --git a/web/src/app/admin/connectors/[connector]/pages/formelements/NumberInput.tsx b/web/src/app/admin/connectors/[connector]/pages/formelements/NumberInput.tsx deleted file mode 100644 index 9e1cf8dcf..000000000 --- a/web/src/app/admin/connectors/[connector]/pages/formelements/NumberInput.tsx +++ /dev/null @@ -1,42 +0,0 @@ -import { SubLabel } from "@/components/admin/connectors/Field"; -import { Field } from "formik"; - -export default function NumberInput({ - label, - value, - optional, - description, - name, - showNeverIfZero, -}: { - value?: number; - label: string; - name: string; - optional?: boolean; - description?: string; - showNeverIfZero?: boolean; -}) { - return ( -
- - {description && {description}} - - -
- ); -} diff --git a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx index 3c9f8e59f..40e2da9cb 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gdrive/Credential.tsx @@ -1,6 +1,6 @@ import { Button } from "@/components/Button"; import { PopupSpec } from "@/components/admin/connectors/Popup"; -import { useState } from "react"; +import React, { useState, useEffect } from "react"; import { useSWRConfig } from "swr"; import * as Yup from "yup"; import { useRouter } from "next/navigation"; @@ -17,13 +17,18 @@ import { GoogleDriveCredentialJson, GoogleDriveServiceAccountCredentialJson, } from "@/lib/connectors/credentials"; +import { refreshAllGoogleData } from "@/lib/googleConnector"; +import { ValidSources } from "@/lib/types"; +import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib"; type GoogleDriveCredentialJsonTypes = "authorized_user" | "service_account"; export const DriveJsonUpload = ({ setPopup, + onSuccess, }: { setPopup: (popupSpec: PopupSpec | null) => void; + onSuccess?: () => void; }) => { const { mutate } = useSWRConfig(); const [credentialJsonStr, setCredentialJsonStr] = useState< @@ -62,7 +67,6 @@ export const DriveJsonUpload = ({ + {isAdmin ? ( + <> +
+ If you want to update these credentials, delete the existing + credentials through the button below, and then upload a new + credentials JSON. +
+ + + ) : ( +
+ To change these credentials, please contact an administrator. +
+ )}
); } @@ -276,14 +327,14 @@ export const GmailJsonUploadSection = ({ > here {" "} - to either (1) setup a google OAuth App in your company workspace or (2) + to either (1) set up a Google OAuth App in your company workspace or (2) create a Service Account.

Download the credentials JSON if choosing option (1) or the Service - Account key JSON if chooosing option (2), and upload it here. + Account key JSON if choosing option (2), and upload it here.

- +
); }; @@ -299,6 +350,34 @@ interface DriveCredentialSectionProps { user: User | null; } +async function handleRevokeAccess( + connectorExists: boolean, + setPopup: (popupSpec: PopupSpec | null) => void, + existingCredential: + | Credential + | Credential, + refreshCredentials: () => void +) { + if (connectorExists) { + const message = + "Cannot revoke the Gmail credential while any connector is still associated with the credential. " + + "Please delete all associated connectors, then try again."; + setPopup({ + message: message, + type: "error", + }); + return; + } + + await adminDeleteCredential(existingCredential.id); + setPopup({ + message: "Successfully revoked the Gmail credential!", + type: "success", + }); + + refreshCredentials(); +} + export const GmailAuthSection = ({ gmailPublicCredential, gmailServiceAccountCredential, @@ -310,31 +389,49 @@ export const GmailAuthSection = ({ user, }: DriveCredentialSectionProps) => { const router = useRouter(); + const [isAuthenticating, setIsAuthenticating] = useState(false); + const [localServiceAccountData, setLocalServiceAccountData] = useState( + serviceAccountKeyData + ); + const [localAppCredentialData, setLocalAppCredentialData] = + useState(appCredentialData); + const [localGmailPublicCredential, setLocalGmailPublicCredential] = useState( + gmailPublicCredential + ); + const [ + localGmailServiceAccountCredential, + setLocalGmailServiceAccountCredential, + ] = useState(gmailServiceAccountCredential); + + // Update local state when props change + useEffect(() => { + setLocalServiceAccountData(serviceAccountKeyData); + setLocalAppCredentialData(appCredentialData); + setLocalGmailPublicCredential(gmailPublicCredential); + setLocalGmailServiceAccountCredential(gmailServiceAccountCredential); + }, [ + serviceAccountKeyData, + appCredentialData, + gmailPublicCredential, + gmailServiceAccountCredential, + ]); const existingCredential = - gmailPublicCredential || gmailServiceAccountCredential; + localGmailPublicCredential || localGmailServiceAccountCredential; if (existingCredential) { return ( <>

- Existing credential already set up! + Uploaded and authenticated credential already exists!

-
- - )} - - + } catch (error) { + setPopup({ + message: `Failed to create service account credential - ${error}`, + type: "error", + }); + } finally { + formikHelpers.setSubmitting(false); + } + }} + > + {({ isSubmitting }) => ( +
+ +
+ +
+ + )} +
); } - if (appCredentialData?.client_id) { + if (localAppCredentialData?.client_id) { return (

Next, you must provide credentials via OAuth. This gives us read - access to the docs you have access to in your gmail account. + access to the emails you have access to in your Gmail account.

); @@ -449,8 +556,8 @@ export const GmailAuthSection = ({ // case where no keys have been uploaded in step 1 return (

- Please upload an OAuth or Service Account Credential JSON in Step 1 before - moving onto Step 2. + Please upload either an OAuth Client Credential JSON or a Gmail Service + Account Key JSON in Step 1 before moving on to Step 2.

); }; diff --git a/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx b/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx index 28af99441..75b97bcba 100644 --- a/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx +++ b/web/src/app/admin/connectors/[connector]/pages/gmail/GmailPage.tsx @@ -1,10 +1,11 @@ "use client"; -import useSWR from "swr"; -import { errorHandlingFetcher } from "@/lib/fetcher"; +import React from "react"; +import { FetchError } from "@/lib/fetcher"; +import { ErrorCallout } from "@/components/ErrorCallout"; import { LoadingAnimation } from "@/components/Loading"; -import { usePopup } from "@/components/admin/connectors/Popup"; -import { CCPairBasicInfo } from "@/lib/types"; +import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup"; +import { CCPairBasicInfo, ValidSources } from "@/lib/types"; import { Credential, GmailCredentialJson, @@ -14,26 +15,33 @@ import { GmailAuthSection, GmailJsonUploadSection } from "./Credential"; import { usePublicCredentials, useBasicConnectorStatus } from "@/lib/hooks"; import Title from "@/components/ui/title"; import { useUser } from "@/components/user/UserProvider"; +import { + useGoogleAppCredential, + useGoogleServiceAccountKey, + useGoogleCredentials, + useConnectorsByCredentialId, + checkCredentialsFetched, + filterUploadedCredentials, + checkConnectorsExist, + refreshAllGoogleData, +} from "@/lib/googleConnector"; export const GmailMain = () => { const { isAdmin, user } = useUser(); + const { popup, setPopup } = usePopup(); const { data: appCredentialData, isLoading: isAppCredentialLoading, error: isAppCredentialError, - } = useSWR<{ client_id: string }>( - "/api/manage/admin/connector/gmail/app-credential", - errorHandlingFetcher - ); + } = useGoogleAppCredential("gmail"); + const { data: serviceAccountKeyData, isLoading: isServiceAccountKeyLoading, error: isServiceAccountKeyError, - } = useSWR<{ service_account_email: string }>( - "/api/manage/admin/connector/gmail/service-account-key", - errorHandlingFetcher - ); + } = useGoogleServiceAccountKey("gmail"); + const { data: connectorIndexingStatuses, isLoading: isConnectorIndexingStatusesLoading, @@ -47,20 +55,45 @@ export const GmailMain = () => { refreshCredentials, } = usePublicCredentials(); - const { popup, setPopup } = usePopup(); + const { + data: gmailCredentials, + isLoading: isGmailCredentialsLoading, + error: gmailCredentialsError, + } = useGoogleCredentials(ValidSources.Gmail); - const appCredentialSuccessfullyFetched = - appCredentialData || - (isAppCredentialError && isAppCredentialError.status === 404); - const serviceAccountKeySuccessfullyFetched = - serviceAccountKeyData || - (isServiceAccountKeyError && isServiceAccountKeyError.status === 404); + const { credential_id, uploadedCredentials } = + filterUploadedCredentials(gmailCredentials); + + const { + data: gmailConnectors, + isLoading: isGmailConnectorsLoading, + error: gmailConnectorsError, + refreshConnectorsByCredentialId, + } = useConnectorsByCredentialId(credential_id); + + const { + appCredentialSuccessfullyFetched, + serviceAccountKeySuccessfullyFetched, + } = checkCredentialsFetched( + appCredentialData, + isAppCredentialError, + serviceAccountKeyData, + isServiceAccountKeyError + ); + + const handleRefresh = () => { + refreshCredentials(); + refreshConnectorsByCredentialId(); + refreshAllGoogleData(ValidSources.Gmail); + }; if ( (!appCredentialSuccessfullyFetched && isAppCredentialLoading) || (!serviceAccountKeySuccessfullyFetched && 
isServiceAccountKeyLoading) || (!connectorIndexingStatuses && isConnectorIndexingStatusesLoading) || - (!credentialsData && isCredentialsLoading) + (!credentialsData && isCredentialsLoading) || + (!gmailCredentials && isGmailCredentialsLoading) || + (!gmailConnectors && isGmailConnectorsLoading) ) { return (
@@ -70,19 +103,15 @@ export const GmailMain = () => { } if (credentialsError || !credentialsData) { - return ( -
-
Failed to load credentials.
-
- ); + return ; + } + + if (gmailCredentialsError || !gmailCredentials) { + return ; } if (connectorIndexingStatusesError || !connectorIndexingStatuses) { - return ( -
-
Failed to load connectors.
-
- ); + return ; } if ( @@ -90,21 +119,28 @@ export const GmailMain = () => { !serviceAccountKeySuccessfullyFetched ) { return ( -
-
- Error loading Gmail app credentials. Contact an administrator. -
-
+ ); } - const gmailPublicCredential: Credential | undefined = - credentialsData.find( - (credential) => - (credential.credential_json?.google_service_account_key || - credential.credential_json?.google_tokens) && - credential.admin_public + if (gmailConnectorsError) { + return ( + ); + } + + const connectorExistsFromCredential = checkConnectorsExist(gmailConnectors); + + const gmailPublicUploadedCredential: + | Credential + | undefined = credentialsData.find( + (credential) => + credential.credential_json?.google_tokens && + credential.admin_public && + credential.source === "gmail" && + credential.credential_json.authentication_method !== "oauth_interactive" + ); + const gmailServiceAccountCredential: | Credential | undefined = credentialsData.find( @@ -118,6 +154,13 @@ export const GmailMain = () => { (connectorIndexingStatus) => connectorIndexingStatus.source === "gmail" ); + const connectorExists = + connectorExistsFromCredential || gmailConnectorIndexingStatuses.length > 0; + + const hasUploadedCredentials = + Boolean(appCredentialData?.client_id) || + Boolean(serviceAccountKeyData?.service_account_email); + return ( <> {popup} @@ -129,21 +172,22 @@ export const GmailMain = () => { appCredentialData={appCredentialData} serviceAccountCredentialData={serviceAccountKeyData} isAdmin={isAdmin} + onSuccess={handleRefresh} /> - {isAdmin && ( + {isAdmin && hasUploadedCredentials && ( <> Step 2: Authenticate with Onyx 0} + connectorExists={connectorExists} user={user} /> diff --git a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx index 41e3f9bef..df5128dce 100644 --- a/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx +++ b/web/src/app/admin/embeddings/EmbeddingModelSelectionForm.tsx @@ -103,42 +103,6 @@ export function EmbeddingModelSelection({ { refreshInterval: 5000 } // 5 seconds ); - const { data: connectors } = useSWR[]>( - "/api/manage/connector", - errorHandlingFetcher, - { refreshInterval: 5000 } // 5 seconds - ); - - const onConfirmSelection = async (model: EmbeddingModelDescriptor) => { - const response = await fetch( - "/api/search-settings/set-new-search-settings", - { - method: "POST", - body: JSON.stringify({ ...model, index_name: null }), - headers: { - "Content-Type": "application/json", - }, - } - ); - if (response.ok) { - setShowTentativeModel(null); - mutate("/api/search-settings/get-secondary-search-settings"); - if (!connectors || !connectors.length) { - setShowAddConnectorPopup(true); - } - } else { - alert(`Failed to update embedding model - ${await response.text()}`); - } - }; - - const onSelectOpenSource = async (model: HostedEmbeddingModel) => { - if (selectedProvider?.model_name === INVALID_OLD_MODEL) { - await onConfirmSelection(model); - } else { - setShowTentativeOpenProvider(model); - } - }; - return (
{alreadySelectedModel && ( @@ -270,7 +234,9 @@ export function EmbeddingModelSelection({ {modelTab == "open" && ( { + setShowTentativeOpenProvider(model); + }} /> )} diff --git a/web/src/app/admin/embeddings/RerankingFormPage.tsx b/web/src/app/admin/embeddings/RerankingFormPage.tsx index 236db2a98..cf0d53933 100644 --- a/web/src/app/admin/embeddings/RerankingFormPage.tsx +++ b/web/src/app/admin/embeddings/RerankingFormPage.tsx @@ -30,6 +30,10 @@ interface RerankingDetailsFormProps { originalRerankingDetails: RerankingDetails; modelTab: "open" | "cloud" | null; setModelTab: Dispatch>; + onValidationChange?: ( + isValid: boolean, + errors: Record + ) => void; } const RerankingDetailsForm = forwardRef< @@ -43,6 +47,7 @@ const RerankingDetailsForm = forwardRef< currentRerankingDetails, modelTab, setModelTab, + onValidationChange, }, ref ) => { @@ -55,26 +60,78 @@ const RerankingDetailsForm = forwardRef< const combinedSettings = useContext(SettingsContext); const gpuEnabled = combinedSettings?.settings.gpu_enabled; + // Define the validation schema + const validationSchema = Yup.object().shape({ + rerank_model_name: Yup.string().nullable(), + rerank_provider_type: Yup.mixed() + .nullable() + .oneOf(Object.values(RerankerProvider)) + .optional(), + rerank_api_key: Yup.string() + .nullable() + .test( + "required-if-cohere", + "API Key is required for Cohere reranking", + function (value) { + const { rerank_provider_type } = this.parent; + return ( + rerank_provider_type !== RerankerProvider.COHERE || + (value !== null && value !== "") + ); + } + ), + rerank_api_url: Yup.string() + .url("Must be a valid URL") + .matches(/^https?:\/\//, "URL must start with http:// or https://") + .nullable() + .test( + "required-if-litellm", + "API URL is required for LiteLLM reranking", + function (value) { + const { rerank_provider_type } = this.parent; + return ( + rerank_provider_type !== RerankerProvider.LITELLM || + (value !== null && value !== "") + ); + } + ), + }); + return ( () - .nullable() - .oneOf(Object.values(RerankerProvider)) - .optional(), - api_key: Yup.string().nullable(), - num_rerank: Yup.number().min(1, "Must be at least 1"), - rerank_api_url: Yup.string() - .url("Must be a valid URL") - .matches(/^https?:\/\//, "URL must start with http:// or https://") - .nullable(), - })} + validationSchema={validationSchema} onSubmit={async (_, { setSubmitting }) => { setSubmitting(false); }} + validate={(values) => { + // Update parent component with values + setRerankingDetails(values); + + // Run validation and report errors + if (onValidationChange) { + // We'll return an empty object here since Yup will handle the actual validation + // But we need to check if there are any validation errors + const errors: Record = {}; + try { + // Manually validate against the schema + validationSchema.validateSync(values, { abortEarly: false }); + onValidationChange(true, {}); + } catch (validationError) { + if (validationError instanceof Yup.ValidationError) { + validationError.inner.forEach((err) => { + if (err.path) { + errors[err.path] = err.message; + } + }); + onValidationChange(false, errors); + } + } + } + + return {}; // Return empty object as Formik will handle the errors + }} enableReinitialize={true} > {({ values, setFieldValue, resetForm }) => { diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts index 2a53acdb1..cc6294548 100644 --- a/web/src/app/admin/embeddings/interfaces.ts +++ b/web/src/app/admin/embeddings/interfaces.ts @@ -20,6 +20,11 @@ export 
enum RerankerProvider { LITELLM = "litellm", } +export enum EmbeddingPrecision { + FLOAT = "float", + BFLOAT16 = "bfloat16", +} + export interface AdvancedSearchConfiguration { index_name: string | null; multipass_indexing: boolean; @@ -27,12 +32,15 @@ export interface AdvancedSearchConfiguration { disable_rerank_for_streaming: boolean; api_url: string | null; num_rerank: number; + embedding_precision: EmbeddingPrecision; + reduced_dimension: number | null; } export interface SavedSearchSettings extends RerankingDetails, AdvancedSearchConfiguration { provider_type: EmbeddingProvider | null; + background_reindex_enabled: boolean; } export interface RerankingModel { diff --git a/web/src/app/admin/embeddings/modals/InstantSwitchConfirmModal.tsx b/web/src/app/admin/embeddings/modals/InstantSwitchConfirmModal.tsx new file mode 100644 index 000000000..ce4bd3400 --- /dev/null +++ b/web/src/app/admin/embeddings/modals/InstantSwitchConfirmModal.tsx @@ -0,0 +1,37 @@ +import { Modal } from "@/components/Modal"; +import { Button } from "@/components/ui/button"; + +interface InstantSwitchConfirmModalProps { + onClose: () => void; + onConfirm: () => void; +} + +export const InstantSwitchConfirmModal = ({ + onClose, + onConfirm, +}: InstantSwitchConfirmModalProps) => { + return ( + + <> +
+ Instant switching will immediately change the embedding model without + re-indexing. Searches will run over only a partial set of documents + (starting from 0 documents) until re-indexing is complete.
+
+ This is not reversible. +
+
+ + +
+ +
+ ); +}; diff --git a/web/src/app/admin/embeddings/modals/ModelSelectionModal.tsx b/web/src/app/admin/embeddings/modals/ModelSelectionModal.tsx index a8c59afdd..12b087d52 100644 --- a/web/src/app/admin/embeddings/modals/ModelSelectionModal.tsx +++ b/web/src/app/admin/embeddings/modals/ModelSelectionModal.tsx @@ -51,9 +51,10 @@ export function ModelSelectionConfirmationModal({ )} -
- +
diff --git a/web/src/app/admin/embeddings/modals/SelectModelModal.tsx b/web/src/app/admin/embeddings/modals/SelectModelModal.tsx index 7b9347a37..ccc4f3f04 100644 --- a/web/src/app/admin/embeddings/modals/SelectModelModal.tsx +++ b/web/src/app/admin/embeddings/modals/SelectModelModal.tsx @@ -21,15 +21,14 @@ export function SelectModelModal({ >
- You're selecting a new embedding model, {model.model_name}. If - you update to this model, you will need to undergo a complete - re-indexing. -
- Are you sure? + You're selecting a new embedding model, {model.model_name} + . If you update to this model, you will need to undergo a complete + re-indexing. Are you sure?
-
- +
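Before the `AdvancedEmbeddingFormPage` changes below, a quick illustration of the two search-settings fields added in `interfaces.ts` above. The values are illustrative; the constraints (minimum of 256, OpenAI-only dimension reduction) come from the form's validation rules below:

```ts
import { EmbeddingPrecision } from "../interfaces";

// Illustrative settings fragment. bfloat16 trades a little precision for
// half the vector storage of float (2 bytes vs 4 per dimension);
// reduced_dimension must be >= 256 and is only supported for OpenAI
// embedding models.
const advancedSettingsFragment = {
  embedding_precision: EmbeddingPrecision.BFLOAT16,
  reduced_dimension: 512, // or null to keep the model's native dimension
};
```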
diff --git a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx index bea80322e..b5003f835 100644 --- a/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/AdvancedEmbeddingFormPage.tsx @@ -3,13 +3,15 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik"; import * as Yup from "yup"; import { TrashIcon } from "@/components/icons/icons"; import { FaPlus } from "react-icons/fa"; -import { AdvancedSearchConfiguration } from "../interfaces"; +import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces"; import { BooleanFormField, Label, SubLabel, + SelectorFormField, } from "@/components/admin/connectors/Field"; import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput"; +import { StringOrNumberOption } from "@/components/Dropdown"; interface AdvancedEmbeddingFormPageProps { updateAdvancedEmbeddingDetails: ( @@ -17,102 +19,207 @@ interface AdvancedEmbeddingFormPageProps { value: any ) => void; advancedEmbeddingDetails: AdvancedSearchConfiguration; + embeddingProviderType: string | null; + onValidationChange?: ( + isValid: boolean, + errors: Record + ) => void; } +// Options for embedding precision based on EmbeddingPrecision enum +const embeddingPrecisionOptions: StringOrNumberOption[] = [ + { name: EmbeddingPrecision.BFLOAT16, value: EmbeddingPrecision.BFLOAT16 }, + { name: EmbeddingPrecision.FLOAT, value: EmbeddingPrecision.FLOAT }, +]; + const AdvancedEmbeddingFormPage = forwardRef< FormikProps, AdvancedEmbeddingFormPageProps ->(({ updateAdvancedEmbeddingDetails, advancedEmbeddingDetails }, ref) => { - return ( -
- { - setSubmitting(false); - }} - validate={(values) => { - // Call updateAdvancedEmbeddingDetails for each changed field - Object.entries(values).forEach(([key, value]) => { - updateAdvancedEmbeddingDetails( - key as keyof AdvancedSearchConfiguration, - value - ); - }); - }} - enableReinitialize={true} - > - {({ values }) => ( -
- - {({ push, remove }) => ( -
- +>( + ( + { + updateAdvancedEmbeddingDetails, + advancedEmbeddingDetails, + embeddingProviderType, + onValidationChange, + }, + ref + ) => { + return ( +
+ value === null || value === undefined || value >= 256 + ) + .test( + "openai", + "Reduced Dimensions is only supported for OpenAI embedding models", + (value) => { + return embeddingProviderType === "openai" || value === null; + } + ), + })} + onSubmit={async (_, { setSubmitting }) => { + setSubmitting(false); + }} + validate={(values) => { + // Call updateAdvancedEmbeddingDetails for each changed field + Object.entries(values).forEach(([key, value]) => { + updateAdvancedEmbeddingDetails( + key as keyof AdvancedSearchConfiguration, + value + ); + }); - Add additional languages to the search. - {values.multilingual_expansion.map( - (_: any, index: number) => ( -
- = {}; + try { + // Manually validate against the schema + Yup.object() + .shape({ + multilingual_expansion: Yup.array().of(Yup.string()), + multipass_indexing: Yup.boolean(), + disable_rerank_for_streaming: Yup.boolean(), + num_rerank: Yup.number() + .required("Number of results to rerank is required") + .min(1, "Must be at least 1"), + embedding_precision: Yup.string().nullable(), + reduced_dimension: Yup.number() + .nullable() + .test( + "positive", + "Must be larger than or equal to 256", + (value) => + value === null || value === undefined || value >= 256 + ) + .test( + "openai", + "Reduced Dimensions is only supported for OpenAI embedding models", + (value) => { + return ( + embeddingProviderType === "openai" || value === null + ); + } + ), + }) + .validateSync(values, { abortEarly: false }); + onValidationChange(true, {}); + } catch (validationError) { + if (validationError instanceof Yup.ValidationError) { + validationError.inner.forEach((err) => { + if (err.path) { + errors[err.path] = err.message; + } + }); + onValidationChange(false, errors); + } + } + } + + return {}; // Return empty object as Formik will handle the errors + }} + enableReinitialize={true} + > + {({ values }) => ( + + + {({ push, remove }) => ( +
+ + + Add additional languages to the search. + {values.multilingual_expansion.map( + (_: any, index: number) => ( +
+ - -
- ) - )} - +
+ ) + )} + -
- )} - + > + + Add Language + +
+ )} + - - - - - )} - -
- ); -}); + + + + + + + + + )} +
+
+ ); + } +); export default AdvancedEmbeddingFormPage; AdvancedEmbeddingFormPage.displayName = "AdvancedEmbeddingFormPage"; diff --git a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx index 78f603cf5..91260d491 100644 --- a/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx +++ b/web/src/app/admin/embeddings/pages/EmbeddingFormPage.tsx @@ -3,10 +3,16 @@ import { usePopup } from "@/components/admin/connectors/Popup"; import { HealthCheckBanner } from "@/components/health/healthcheck"; import { EmbeddingModelSelection } from "../EmbeddingModelSelectionForm"; -import { useCallback, useEffect, useMemo, useState } from "react"; +import { useCallback, useEffect, useMemo, useState, useRef } from "react"; import Text from "@/components/ui/text"; import { Button } from "@/components/ui/button"; -import { ArrowLeft, ArrowRight, WarningCircle } from "@phosphor-icons/react"; +import { + ArrowLeft, + ArrowRight, + WarningCircle, + CaretDown, + Warning, +} from "@phosphor-icons/react"; import { CloudEmbeddingModel, EmbeddingProvider, @@ -19,16 +25,35 @@ import { ThreeDotsLoader } from "@/components/Loading"; import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage"; import { AdvancedSearchConfiguration, + EmbeddingPrecision, RerankingDetails, SavedSearchSettings, } from "../interfaces"; import RerankingDetailsForm from "../RerankingFormPage"; import { useEmbeddingFormContext } from "@/components/context/EmbeddingContext"; import { Modal } from "@/components/Modal"; +import { InstantSwitchConfirmModal } from "../modals/InstantSwitchConfirmModal"; import { useRouter } from "next/navigation"; import CardSection from "@/components/admin/CardSection"; import { combineSearchSettings } from "./utils"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "@/components/ui/tooltip"; + +enum ReindexType { + REINDEX = "reindex", + INSTANT = "instant", +} export default function EmbeddingForm() { const { formStep, nextFormStep, prevFormStep } = useEmbeddingFormContext(); @@ -43,6 +68,8 @@ export default function EmbeddingForm() { disable_rerank_for_streaming: false, api_url: null, num_rerank: 0, + embedding_precision: EmbeddingPrecision.FLOAT, + reduced_dimension: null, }); const [rerankingDetails, setRerankingDetails] = useState({ @@ -52,6 +79,19 @@ export default function EmbeddingForm() { rerank_api_url: null, }); + const [reindexType, setReindexType] = useState( + ReindexType.REINDEX + ); + + const [formErrors, setFormErrors] = useState>({}); + const [isFormValid, setIsFormValid] = useState(true); + const [rerankFormErrors, setRerankFormErrors] = useState< + Record + >({}); + const [isRerankFormValid, setIsRerankFormValid] = useState(true); + const advancedFormRef = useRef(null); + const rerankFormRef = useRef(null); + const updateAdvancedEmbeddingDetails = ( key: keyof AdvancedSearchConfiguration, value: any @@ -82,6 +122,8 @@ export default function EmbeddingForm() { }; const [displayPoorModelName, setDisplayPoorModelName] = useState(true); const [showPoorModel, setShowPoorModel] = useState(false); + const [showInstantSwitchConfirm, setShowInstantSwitchConfirm] = + useState(false); const [modelTab, setModelTab] = useState<"open" | "cloud" | null>(null); const { @@ -115,6 +157,8 @@ export default function EmbeddingForm() { searchSettings.disable_rerank_for_streaming, 
num_rerank: searchSettings.num_rerank, api_url: null, + embedding_precision: searchSettings.embedding_precision, + reduced_dimension: searchSettings.reduced_dimension, }); setRerankingDetails({ @@ -146,17 +190,14 @@ export default function EmbeddingForm() { } }, [currentEmbeddingModel]); - const handleReindex = async () => { - const update = await updateSearch(); - if (update) { - await onConfirm(); - } - }; - const needsReIndex = currentEmbeddingModel != selectedProvider || searchSettings?.multipass_indexing != - advancedEmbeddingDetails.multipass_indexing; + advancedEmbeddingDetails.multipass_indexing || + searchSettings?.embedding_precision != + advancedEmbeddingDetails.embedding_precision || + searchSettings?.reduced_dimension != + advancedEmbeddingDetails.reduced_dimension; const updateSearch = useCallback(async () => { if (!selectedProvider) { @@ -166,18 +207,44 @@ export default function EmbeddingForm() { selectedProvider, advancedEmbeddingDetails, rerankingDetails, - selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null + selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null, + reindexType === ReindexType.REINDEX ); const response = await updateSearchSettings(searchSettings); if (response.ok) { return true; } else { - setPopup({ message: "Failed to update search settings", type: "error" }); + setPopup({ + message: "Failed to update search settings", + type: "error", + }); return false; } }, [selectedProvider, advancedEmbeddingDetails, rerankingDetails, setPopup]); + const handleValidationChange = useCallback( + (isValid: boolean, errors: Record) => { + setIsFormValid(isValid); + setFormErrors(errors); + }, + [] + ); + + const handleRerankValidationChange = useCallback( + (isValid: boolean, errors: Record) => { + setIsRerankFormValid(isValid); + setRerankFormErrors(errors); + }, + [] + ); + + // Combine validation states for both forms + const isOverallFormValid = isFormValid && isRerankFormValid; + const combinedFormErrors = useMemo(() => { + return { ...formErrors, ...rerankFormErrors }; + }, [formErrors, rerankFormErrors]); + const ReIndexingButton = useMemo(() => { const ReIndexingButtonComponent = ({ needsReIndex, @@ -186,47 +253,204 @@ export default function EmbeddingForm() { }) => { return needsReIndex ? (
- -
- -
-

Needs re-indexing due to:

-
    - {currentEmbeddingModel != selectedProvider && ( -
  • Changed embedding provider
  • - )} - {searchSettings?.multipass_indexing != - advancedEmbeddingDetails.multipass_indexing && ( -
  • Multipass indexing modification
  • - )} -
-
+
+ + + + + + + { + setReindexType(ReindexType.REINDEX); + }} + > + + + + (Recommended) Re-index + + +

+ Re-runs all connectors in the background before + switching over. Takes longer but ensures no + degradation of search during the switch.

+
+
+
+
+ { + setReindexType(ReindexType.INSTANT); + }} + > + + + + Instant Switch + + +

+ Immediately switches to new settings without + re-indexing. Searches will be degraded until the + re-indexing is complete. +

+
+
+
+
+
+
+ {isOverallFormValid && ( +
+ +
+

Needs re-indexing due to:

+
    + {currentEmbeddingModel != selectedProvider && ( +
  • Changed embedding provider
  • + )} + {searchSettings?.multipass_indexing != + advancedEmbeddingDetails.multipass_indexing && ( +
  • Multipass indexing modification
  • + )} + {searchSettings?.embedding_precision != + advancedEmbeddingDetails.embedding_precision && ( +
  • Embedding precision modification
  • + )} + {searchSettings?.reduced_dimension != + advancedEmbeddingDetails.reduced_dimension && ( +
  • Reduced dimension modification
  • + )} +
+
+
+ )} + {!isOverallFormValid && + Object.keys(combinedFormErrors).length > 0 && ( +
+ +
+

Validation Errors:

+
    + {Object.entries(combinedFormErrors).map( + ([field, error]) => ( +
  • + {field}: {error} +
  • + ) + )} +
+
+
+ )}
) : ( - +
+ + {!isOverallFormValid && + Object.keys(combinedFormErrors).length > 0 && ( +
+ +
+

+ Validation Errors: +

+
    + {Object.entries(combinedFormErrors).map( + ([field, error]) => ( +
  • {error}
  • + ) + )} +
+
+
+ )} +
); }; ReIndexingButtonComponent.displayName = "ReIndexingButton"; return ReIndexingButtonComponent; - }, [needsReIndex, updateSearch]); + }, [needsReIndex, reindexType, isOverallFormValid, combinedFormErrors]); if (!selectedProvider) { return ; @@ -246,7 +470,7 @@ export default function EmbeddingForm() { router.push("/admin/configuration/search?message=search-settings"); }; - const onConfirm = async () => { + const handleReIndex = async () => { if (!selectedProvider) { return; } @@ -260,7 +484,8 @@ export default function EmbeddingForm() { rerankingDetails, selectedProvider.provider_type ?.toLowerCase() - .split(" ")[0] as EmbeddingProvider | null + .split(" ")[0] as EmbeddingProvider | null, + reindexType === ReindexType.REINDEX ); } else { // This is a locally hosted model @@ -268,7 +493,8 @@ export default function EmbeddingForm() { selectedProvider, advancedEmbeddingDetails, rerankingDetails, - null + null, + reindexType === ReindexType.REINDEX ); } @@ -381,6 +607,17 @@ export default function EmbeddingForm() { )} + {showInstantSwitchConfirm && ( + setShowInstantSwitchConfirm(false)} + onConfirm={() => { + setShowInstantSwitchConfirm(false); + handleReIndex(); + navigateToEmbeddingPage("search settings"); + }} + /> + )} + {formStep == 1 && ( <>

@@ -395,6 +632,7 @@ export default function EmbeddingForm() { @@ -444,8 +683,11 @@ export default function EmbeddingForm() { diff --git a/web/src/app/admin/embeddings/pages/OpenEmbeddingPage.tsx b/web/src/app/admin/embeddings/pages/OpenEmbeddingPage.tsx index e84781f37..a9c2f91d9 100644 --- a/web/src/app/admin/embeddings/pages/OpenEmbeddingPage.tsx +++ b/web/src/app/admin/embeddings/pages/OpenEmbeddingPage.tsx @@ -16,7 +16,7 @@ export default function OpenEmbeddingPage({ onSelectOpenSource, selectedProvider, }: { - onSelectOpenSource: (model: HostedEmbeddingModel) => Promise; + onSelectOpenSource: (model: HostedEmbeddingModel) => void; selectedProvider: HostedEmbeddingModel | CloudEmbeddingModel; }) { const [configureModel, setConfigureModel] = useState(false); diff --git a/web/src/app/admin/embeddings/pages/utils.ts b/web/src/app/admin/embeddings/pages/utils.ts index 039b14242..4889091a2 100644 --- a/web/src/app/admin/embeddings/pages/utils.ts +++ b/web/src/app/admin/embeddings/pages/utils.ts @@ -63,12 +63,14 @@ export const combineSearchSettings = ( selectedProvider: CloudEmbeddingProvider | HostedEmbeddingModel, advancedEmbeddingDetails: AdvancedSearchConfiguration, rerankingDetails: RerankingDetails, - provider_type: EmbeddingProvider | null + provider_type: EmbeddingProvider | null, + background_reindex_enabled: boolean ): SavedSearchSettings => { return { ...selectedProvider, ...advancedEmbeddingDetails, ...rerankingDetails, provider_type: provider_type, + background_reindex_enabled, }; }; diff --git a/web/src/app/admin/users/page.tsx b/web/src/app/admin/users/page.tsx index 5a0d3009f..78d7933e5 100644 --- a/web/src/app/admin/users/page.tsx +++ b/web/src/app/admin/users/page.tsx @@ -19,6 +19,8 @@ import BulkAdd from "@/components/admin/users/BulkAdd"; import Text from "@/components/ui/text"; import { InvitedUserSnapshot } from "@/lib/types"; import { SearchBar } from "@/components/search/SearchBar"; +import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal"; +import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; const UsersTables = ({ q, @@ -130,6 +132,13 @@ const AddUserButton = ({ setPopup: (spec: PopupSpec) => void; }) => { const [modal, setModal] = useState(false); + const [showConfirmation, setShowConfirmation] = useState(false); + + const { data: invitedUsers } = useSWR( + "/api/manage/users/invited", + errorHandlingFetcher + ); + const onSuccess = () => { mutate( (key) => typeof key === "string" && key.startsWith("/api/manage/users") @@ -140,6 +149,7 @@ const AddUserButton = ({ type: "success", }); }; + const onFailure = async (res: Response) => { const error = (await res.json()).detail; setPopup({ @@ -147,15 +157,45 @@ const AddUserButton = ({ type: "error", }); }; + + const handleInviteClick = () => { + if ( + !NEXT_PUBLIC_CLOUD_ENABLED && + invitedUsers && + invitedUsers.length === 0 + ) { + setShowConfirmation(true); + } else { + setModal(true); + } + }; + + const handleConfirmFirstInvite = () => { + setShowConfirmation(false); + setModal(true); + }; + return ( <> - + {showConfirmation && ( + setShowConfirmation(false)} + onSubmit={handleConfirmFirstInvite} + additionalDetails="After inviting the first user, only invited users will be able to join this platform. This is a security measure to control access to your instance." + actionButtonText="Continue" + variant="action" + /> + )} + {modal && ( setModal(false)}>
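The `AddUserButton` change above only interposes a confirmation before the very first invite, and only on self-hosted deployments; the gate itself reduces to:

```ts
// Condensed from handleInviteClick above: once any user has been invited,
// the platform becomes invite-only, so the very first invite on a
// self-hosted instance deserves an explicit confirmation.
function needsFirstInviteConfirmation(
  cloudEnabled: boolean,
  invitedUsers: unknown[] | undefined
): boolean {
  return (
    !cloudEnabled && invitedUsers !== undefined && invitedUsers.length === 0
  );
}
```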
diff --git a/web/src/app/assistants/mine/AssistantModal.tsx b/web/src/app/assistants/mine/AssistantModal.tsx index b5876c36d..c816e5172 100644 --- a/web/src/app/assistants/mine/AssistantModal.tsx +++ b/web/src/app/assistants/mine/AssistantModal.tsx @@ -61,13 +61,9 @@ const useAssistantFilter = () => { interface AssistantModalProps { hideModal: () => void; - modalHeight?: string; } -export function AssistantModal({ - hideModal, - modalHeight, -}: AssistantModalProps) { +export function AssistantModal({ hideModal }: AssistantModalProps) { const { assistants, pinnedAssistants } = useAssistants(); const { assistantFilters, toggleAssistantFilter } = useAssistantFilter(); const router = useRouter(); diff --git a/web/src/app/auth/error/layout.tsx b/web/src/app/auth/error/layout.tsx new file mode 100644 index 000000000..f9e9b6315 --- /dev/null +++ b/web/src/app/auth/error/layout.tsx @@ -0,0 +1,16 @@ +export default function AuthErrorLayout({ + children, +}: { + children: React.ReactNode; +}) { + // Log error to console for debugging + console.error( + "Authentication error page was accessed - this should not happen in normal flow" + ); + + // In a production environment, you might want to send this to your error tracking service + // For example, if using a service like Sentry: + // captureException(new Error("Authentication error page was accessed unexpectedly")); + + return <>{children}; +} diff --git a/web/src/app/auth/error/page.tsx b/web/src/app/auth/error/page.tsx index 3bee49841..ec2e01e7b 100644 --- a/web/src/app/auth/error/page.tsx +++ b/web/src/app/auth/error/page.tsx @@ -4,6 +4,7 @@ import AuthFlowContainer from "@/components/auth/AuthFlowContainer"; import { Button } from "@/components/ui/button"; import Link from "next/link"; import { FiLogIn } from "react-icons/fi"; +import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants"; const Page = () => { return ( @@ -15,19 +16,21 @@ const Page = () => {

We encountered an issue while attempting to log you in.

-
-

Possible Issues:

+
+

+ Possible Issues: +

    -
  • -
    +
  • +
    Incorrect or expired login credentials
  • -
  • -
    +
  • +
    Temporary authentication system disruption
  • -
  • -
    +
  • +
    Account access restrictions or permissions
@@ -41,6 +44,12 @@ const Page = () => {

We recommend trying again. If you continue to experience problems, please reach out to your system administrator for assistance. + {NEXT_PUBLIC_CLOUD_ENABLED && ( + + A member of our team has been automatically notified about this + issue. + + )}
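The new `auth/error/layout.tsx` above logs unexpected visits with `console.error` and sketches, in a comment, forwarding them to an error tracker. A minimal version of that wiring, assuming Sentry's Next.js SDK is installed (`@sentry/nextjs` is an assumption; this diff does not add the dependency):

```ts
import * as Sentry from "@sentry/nextjs";

export default function AuthErrorLayout({
  children,
}: {
  children: React.ReactNode;
}) {
  // Report the unexpected visit to the tracker in addition to the console,
  // as the layout's comment suggests.
  Sentry.captureException(
    new Error("Authentication error page was accessed unexpectedly")
  );

  return <>{children}</>;
}
```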

diff --git a/web/src/app/auth/login/EmailPasswordForm.tsx b/web/src/app/auth/login/EmailPasswordForm.tsx index 6f7a48cc9..f63eefe38 100644 --- a/web/src/app/auth/login/EmailPasswordForm.tsx +++ b/web/src/app/auth/login/EmailPasswordForm.tsx @@ -36,19 +36,25 @@ export function EmailPasswordForm({ {popup} value.toLowerCase()), password: Yup.string().required(), })} onSubmit={async (values) => { + // Ensure email is lowercase + const email = values.email.toLowerCase(); + if (isSignup) { // login is fast, no need to show a spinner setIsWorking(true); const response = await basicSignup( - values.email, + email, values.password, referralSource ); @@ -75,10 +81,10 @@ export function EmailPasswordForm({ } } - const loginResponse = await basicLogin(values.email, values.password); + const loginResponse = await basicLogin(email, values.password); if (loginResponse.ok) { if (isSignup && shouldVerify) { - await requestEmailVerification(values.email); + await requestEmailVerification(email); // Use window.location.href to force a full page reload, // ensuring app re-initializes with the new state (including // server-side provider values) diff --git a/web/src/app/chat/folders/FolderList.tsx b/web/src/app/chat/folders/FolderList.tsx index 1d63f943b..61e532f83 100644 --- a/web/src/app/chat/folders/FolderList.tsx +++ b/web/src/app/chat/folders/FolderList.tsx @@ -168,7 +168,7 @@ const FolderItem = ({ }; const folders = folder.chat_sessions.sort((a, b) => { - return a.time_created.localeCompare(b.time_created); + return a.time_updated.localeCompare(b.time_updated); }); // Determine whether to show the trash can icon diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 6f614c7d2..f5daad62b 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -70,6 +70,7 @@ export interface ChatSession { name: string; persona_id: number; time_created: string; + time_updated: string; shared_status: ChatSessionSharedStatus; folder_id: number | null; current_alternate_model: string; @@ -123,6 +124,7 @@ export interface BackendChatSession { persona_icon_shape: number | null; messages: BackendMessage[]; time_created: string; + time_updated: string; shared_status: ChatSessionSharedStatus; current_temperature_override: number | null; current_alternate_model?: string; diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 62c908d55..a95c64d14 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -48,10 +48,10 @@ export function getChatRetentionInfo( ): ChatRetentionInfo { // If `maximum_chat_retention_days` isn't set- never display retention warning. 
const chatRetentionDays = settings.maximum_chat_retention_days || 10000; - const createdDate = new Date(chatSession.time_created); + const updatedDate = new Date(chatSession.time_updated); const today = new Date(); const daysFromCreation = Math.ceil( - (today.getTime() - createdDate.getTime()) / (1000 * 3600 * 24) + (today.getTime() - updatedDate.getTime()) / (1000 * 3600 * 24) ); const daysUntilExpiration = chatRetentionDays - daysFromCreation; const showRetentionWarning = @@ -419,7 +419,7 @@ export function groupSessionsByDateRange(chatSessions: ChatSession[]) { }; chatSessions.forEach((chatSession) => { - const chatSessionDate = new Date(chatSession.time_created); + const chatSessionDate = new Date(chatSession.time_updated); const diffTime = today.getTime() - chatSessionDate.getTime(); const diffDays = diffTime / (1000 * 3600 * 24); // Convert time difference to days @@ -501,6 +501,7 @@ export function processRawChatHistory( sub_questions: subQuestions, isImprovement: (messageInfo.refined_answer_improvement as unknown as boolean) || false, + is_agentic: messageInfo.is_agentic, }; messages.set(messageInfo.message_id, message); diff --git a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx index 73fbc3639..a59249e7c 100644 --- a/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx +++ b/web/src/app/chat/shared/[chatId]/SharedChatDisplay.tsx @@ -206,7 +206,7 @@ export function SharedChatDisplay({ {chatSession.description || `Unnamed Chat`}

- {humanReadableFormat(chatSession.time_created)}
+ {humanReadableFormat(chatSession.time_updated)}

[markup-only diff around {title} and {value}; tags lost in extraction]

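Taken together, the time_created to time_updated changes (FolderList, interfaces.ts, lib.tsx, SharedChatDisplay, fetchChatData) key every recency computation off the last activity rather than creation time, so actively used chats sort first and never hit the retention cutoff. A self-contained sketch of the two patterns the diff applies; the trimmed-down session type is illustrative, not the real ChatSession:

```ts
// Minimal sketch of the recency logic after this diff. The type and sample
// fields are hypothetical; the expressions mirror fetchChatData and
// getChatRetentionInfo above.
interface SessionTimes {
  time_created: string; // ISO timestamp, set once at creation
  time_updated: string; // ISO timestamp, bumped on each new message
}

function sortByRecency<T extends SessionTimes>(sessions: T[]): T[] {
  // Newest activity first, as in fetchChatData's sort.
  return [...sessions].sort(
    (a, b) =>
      new Date(b.time_updated).getTime() - new Date(a.time_updated).getTime()
  );
}

function daysUntilExpiration(
  session: SessionTimes,
  retentionDays: number
): number {
  // Retention now counts from the last update, as in getChatRetentionInfo.
  const ageDays = Math.ceil(
    (Date.now() - new Date(session.time_updated).getTime()) /
      (1000 * 3600 * 24)
  );
  return retentionDays - ageDays;
}
```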
diff --git a/web/src/components/Button.tsx b/web/src/components/Button.tsx index 5672469f1..ea6faadc3 100644 --- a/web/src/components/Button.tsx +++ b/web/src/components/Button.tsx @@ -3,7 +3,6 @@ interface Props { onClick?: React.MouseEventHandler; type?: "button" | "submit" | "reset"; disabled?: boolean; - fullWidth?: boolean; className?: string; } @@ -12,14 +11,12 @@ export const Button = ({ onClick, type = "submit", disabled = false, - fullWidth = false, className = "", }: Props) => { return (
[markup lost in extraction: Button JSX body plus an orphaned fragment {access_type.value === "sync" && isAutoSyncSupported && ( )}] diff --git a/web/src/components/admin/connectors/Field.tsx index 851e3b380..b302073d9 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -51,13 +51,13 @@ export function Label({ className?: string; }) { return ( [label element markup changed; tags lost in extraction] {children}
+ ); } @@ -132,7 +132,6 @@ export function TextFormField({ label, subtext, placeholder, - value, type = "text", optional, includeRevert, @@ -157,7 +156,6 @@ export function TextFormField({ vertical, className, }: { - value?: string; // Escape hatch for setting the value of the field - conflicts with Formik name: string; removeLabel?: boolean; label: string; @@ -253,7 +251,6 @@ export function TextFormField({ min={min} as={isTextArea ? "textarea" : "input"} type={type} - defaultValue={value} data-testid={name} name={name} id={name} @@ -689,7 +686,7 @@ export function SelectorFormField({ defaultValue, tooltip, includeReset = false, - fontSize = "sm", + fontSize = "md", small = false, }: SelectorFormFieldProps) { const [field] = useField(name); diff --git a/web/src/components/context/AppProvider.tsx b/web/src/components/context/AppProvider.tsx index 373085d1a..c7f088922 100644 --- a/web/src/components/context/AppProvider.tsx +++ b/web/src/components/context/AppProvider.tsx @@ -6,9 +6,6 @@ import { SettingsProvider } from "../settings/SettingsProvider"; import { AssistantsProvider } from "./AssistantsContext"; import { Persona } from "@/app/admin/assistants/interfaces"; import { User } from "@/lib/types"; -import { fetchChatData } from "@/lib/chat/fetchChatData"; -import { ChatProvider } from "./ChatContext"; -import { redirect } from "next/navigation"; interface AppProviderProps { children: React.ReactNode; diff --git a/web/src/components/credentials/CredentialFields.tsx b/web/src/components/credentials/CredentialFields.tsx index d23b30834..70e6f8809 100644 --- a/web/src/components/credentials/CredentialFields.tsx +++ b/web/src/components/credentials/CredentialFields.tsx @@ -130,6 +130,7 @@ interface BooleanFormFieldProps { small?: boolean; alignTop?: boolean; noLabel?: boolean; + disabled?: boolean; onChange?: (e: React.ChangeEvent) => void; } @@ -141,6 +142,7 @@ export const AdminBooleanFormField = ({ small, checked, alignTop, + disabled = false, onChange, }: BooleanFormFieldProps) => { const [field, meta, helpers] = useField(name); @@ -152,6 +154,7 @@ export const AdminBooleanFormField = ({ type="checkbox" {...field} checked={Boolean(field.value)} + disabled={disabled} onChange={(e) => { helpers.setValue(e.target.checked); }} diff --git a/web/src/components/embedding/ReindexingProgressTable.tsx b/web/src/components/embedding/ReindexingProgressTable.tsx index ad26c595c..186f3444a 100644 --- a/web/src/components/embedding/ReindexingProgressTable.tsx +++ b/web/src/components/embedding/ReindexingProgressTable.tsx @@ -29,6 +29,7 @@ export function ReindexingProgressTable({ Connector Name Status Docs Re-Indexed + diff --git a/web/src/components/embedding/interfaces.tsx b/web/src/components/embedding/interfaces.tsx index e03d163f9..1ac3f5da6 100644 --- a/web/src/components/embedding/interfaces.tsx +++ b/web/src/components/embedding/interfaces.tsx @@ -55,6 +55,7 @@ export interface EmbeddingModelDescriptor { api_version?: string | null; deployment_name?: string | null; index_name: string | null; + background_reindex_enabled?: boolean; } export interface CloudEmbeddingModel extends EmbeddingModelDescriptor { diff --git a/web/src/components/ui/alert.tsx b/web/src/components/ui/alert.tsx index ee668e2f4..ca648c798 100644 --- a/web/src/components/ui/alert.tsx +++ b/web/src/components/ui/alert.tsx @@ -11,11 +11,11 @@ const alertVariants = cva( broken: "border-red-500/50 text-red-500 dark:border-red-500 [&>svg]:text-red-500 dark:border-red-900/50 dark:text-red-100 dark:dark:border-red-900 
dark:[&>svg]:text-red-700 bg-red-50 dark:bg-red-950", ark: "border-amber-500/50 text-amber-500 dark:border-amber-500 [&>svg]:text-amber-500 dark:border-amber-900/50 dark:text-amber-900 dark:dark:border-amber-900 dark:[&>svg]:text-amber-900 bg-amber-50 dark:bg-amber-950", - info: "border-black/50 dark:border-black dark:border-black/50 dark:dark:border-black", + info: "border-[#fff]/50 dark:border-[#fff] dark:border-[#fff]/50 dark:dark:border-[#fff]", default: "bg-neutral-50 text-neutral-darker dark:bg-neutral-950 dark:text-text", destructive: - "border-red-500/50 text-red-500 dark:border-red-500 [&>svg]:text-red-500 dark:border-red-900/50 dark:text-red-900 dark:dark:border-red-900 dark:[&>svg]:text-red-900", + "border-red-500/50 text-red-500 dark:border-red-500 [&>svg]:text-red-500 dark:border-red-900/50 dark:text-red-600 dark:dark:border-red-900 dark:[&>svg]:text-red-900", }, }, defaultVariants: { diff --git a/web/src/lib/assistants/utils.ts b/web/src/lib/assistants/utils.ts index 14c3eabc2..5c3ab5adc 100644 --- a/web/src/lib/assistants/utils.ts +++ b/web/src/lib/assistants/utils.ts @@ -131,7 +131,8 @@ export function filterAssistants( if (!hasAnyConnectors) { filteredAssistants = filteredAssistants.filter( - (assistant) => assistant.num_chunks === 0 + (assistant) => + assistant.num_chunks === 0 || assistant.document_sets.length > 0 ); } diff --git a/web/src/lib/chat/fetchChatData.ts b/web/src/lib/chat/fetchChatData.ts index 88a646437..91b98671f 100644 --- a/web/src/lib/chat/fetchChatData.ts +++ b/web/src/lib/chat/fetchChatData.ts @@ -148,7 +148,7 @@ export async function fetchChatData(searchParams: { chatSessions.sort( (a, b) => - new Date(b.time_created).getTime() - new Date(a.time_created).getTime() + new Date(b.time_updated).getTime() - new Date(a.time_updated).getTime() ); let documentSets: DocumentSet[] = []; diff --git a/web/src/lib/connectors/connectors.tsx b/web/src/lib/connectors/connectors.tsx index 37cb373b9..99ccde947 100644 --- a/web/src/lib/connectors/connectors.tsx +++ b/web/src/lib/connectors/connectors.tsx @@ -43,6 +43,7 @@ export interface Option { currentCredential: Credential | null ) => boolean; wrapInCollapsible?: boolean; + disabled?: boolean | ((currentCredential: Credential | null) => boolean); } export interface SelectOption extends Option { @@ -60,6 +61,7 @@ export interface ListOption extends Option { export interface TextOption extends Option { type: "text"; default?: string; + initial?: string | ((currentCredential: Credential | null) => string); isTextArea?: boolean; } @@ -105,6 +107,7 @@ export interface TabOption extends Option { export interface ConnectionConfiguration { description: string; subtext?: string; + initialConnectorName?: string; // a key in the credential to prepopulate the connector name field values: ( | BooleanOption | ListOption @@ -389,6 +392,7 @@ export const connectorConfigs: Record< }, confluence: { description: "Configure Confluence connector", + initialConnectorName: "cloud_name", values: [ { type: "checkbox", @@ -399,6 +403,12 @@ export const connectorConfigs: Record< default: true, description: "Check if this is a Confluence Cloud instance, uncheck for Confluence Server/Data Center", + disabled: (currentCredential) => { + if (currentCredential?.credential_json?.confluence_refresh_token) { + return true; + } + return false; + }, }, { type: "text", @@ -406,6 +416,15 @@ export const connectorConfigs: Record< label: "Wiki Base URL", name: "wiki_base", optional: false, + initial: (currentCredential) => { + return 
currentCredential?.credential_json?.wiki_base ?? ""; + }, + disabled: (currentCredential) => { + if (currentCredential?.credential_json?.confluence_refresh_token) { + return true; + } + return false; + }, description: "The base URL of your Confluence instance (e.g., https://your-domain.atlassian.net/wiki)", }, diff --git a/web/src/lib/connectors/oauth.ts b/web/src/lib/connectors/oauth.ts index d3472ccba..7100f2609 100644 --- a/web/src/lib/connectors/oauth.ts +++ b/web/src/lib/connectors/oauth.ts @@ -27,6 +27,9 @@ export async function getConnectorOauthRedirectUrl( export function useOAuthDetails(sourceType: ValidSources) { return useSWR( `/api/connector/oauth/details/${sourceType}`, - errorHandlingFetcher + errorHandlingFetcher, + { + shouldRetryOnError: false, + } ); } diff --git a/web/src/lib/googleConnector.ts b/web/src/lib/googleConnector.ts new file mode 100644 index 000000000..4e78c8584 --- /dev/null +++ b/web/src/lib/googleConnector.ts @@ -0,0 +1,120 @@ +import useSWR, { mutate } from "swr"; +import { FetchError, errorHandlingFetcher } from "@/lib/fetcher"; +import { Credential } from "@/lib/connectors/credentials"; +import { ConnectorSnapshot } from "@/lib/connectors/connectors"; +import { ValidSources } from "@/lib/types"; +import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib"; + +// Constants for service names to avoid typos +export const GOOGLE_SERVICES = { + GMAIL: "gmail", + GOOGLE_DRIVE: "google-drive", +} as const; + +export const useGoogleAppCredential = (service: "gmail" | "google_drive") => { + const endpoint = `/api/manage/admin/connector/${ + service === "gmail" ? GOOGLE_SERVICES.GMAIL : GOOGLE_SERVICES.GOOGLE_DRIVE + }/app-credential`; + + return useSWR<{ client_id: string }, FetchError>( + endpoint, + errorHandlingFetcher + ); +}; + +export const useGoogleServiceAccountKey = ( + service: "gmail" | "google_drive" +) => { + const endpoint = `/api/manage/admin/connector/${ + service === "gmail" ? 
GOOGLE_SERVICES.GMAIL : GOOGLE_SERVICES.GOOGLE_DRIVE + }/service-account-key`; + + return useSWR<{ service_account_email: string }, FetchError>( + endpoint, + errorHandlingFetcher + ); +}; + +export const useGoogleCredentials = ( + source: ValidSources.Gmail | ValidSources.GoogleDrive +) => { + return useSWR[]>( + buildSimilarCredentialInfoURL(source), + errorHandlingFetcher, + { refreshInterval: 5000 } + ); +}; + +export const useConnectorsByCredentialId = (credential_id: number | null) => { + let url: string | null = null; + if (credential_id !== null) { + url = `/api/manage/admin/connector?credential=${credential_id}`; + } + const swrResponse = useSWR(url, errorHandlingFetcher); + + return { + ...swrResponse, + refreshConnectorsByCredentialId: () => mutate(url), + }; +}; + +export const checkCredentialsFetched = ( + appCredentialData: any, + appCredentialError: FetchError | undefined, + serviceAccountKeyData: any, + serviceAccountKeyError: FetchError | undefined +) => { + const appCredentialSuccessfullyFetched = + appCredentialData || + (appCredentialError && appCredentialError.status === 404); + + const serviceAccountKeySuccessfullyFetched = + serviceAccountKeyData || + (serviceAccountKeyError && serviceAccountKeyError.status === 404); + + return { + appCredentialSuccessfullyFetched, + serviceAccountKeySuccessfullyFetched, + }; +}; + +export const filterUploadedCredentials = < + T extends { authentication_method?: string }, +>( + credentials: Credential[] | undefined +): { credential_id: number | null; uploadedCredentials: Credential[] } => { + let credential_id = null; + let uploadedCredentials: Credential[] = []; + + if (credentials) { + uploadedCredentials = credentials.filter( + (credential) => + credential.credential_json.authentication_method !== "oauth_interactive" + ); + + if (uploadedCredentials.length > 0) { + credential_id = uploadedCredentials[0].id; + } + } + + return { credential_id, uploadedCredentials }; +}; + +export const checkConnectorsExist = ( + connectors: ConnectorSnapshot[] | undefined +): boolean => { + return !!connectors && connectors.length > 0; +}; + +export const refreshAllGoogleData = ( + source: ValidSources.Gmail | ValidSources.GoogleDrive +) => { + mutate(buildSimilarCredentialInfoURL(source)); + + const service = + source === ValidSources.Gmail + ? 
GOOGLE_SERVICES.GMAIL + : GOOGLE_SERVICES.GOOGLE_DRIVE; + mutate(`/api/manage/admin/connector/${service}/app-credential`); + mutate(`/api/manage/admin/connector/${service}/service-account-key`); +}; diff --git a/web/src/lib/oauth_utils.ts b/web/src/lib/oauth_utils.ts index 5f4329a15..db3342efe 100644 --- a/web/src/lib/oauth_utils.ts +++ b/web/src/lib/oauth_utils.ts @@ -1,5 +1,7 @@ import { - OAuthGoogleDriveCallbackResponse, + OAuthBaseCallbackResponse, + OAuthConfluenceFinalizeResponse, + OAuthConfluencePrepareFinalizationResponse, OAuthPrepareAuthorizationResponse, OAuthSlackCallbackResponse, } from "./types"; @@ -53,6 +55,10 @@ export async function handleOAuthAuthorizationResponse( return handleOAuthGoogleDriveAuthorizationResponse(code, state); } + if (connector === "confluence") { + return handleOAuthConfluenceAuthorizationResponse(code, state); + } + return; } @@ -75,7 +81,7 @@ export async function handleOAuthSlackAuthorizationResponse( }); if (!response.ok) { - let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`; + let errorDetails = `Failed to handle OAuth Slack authorization response: ${response.status}`; try { const responseBody = await response.text(); // Read the body as text @@ -96,12 +102,10 @@ export async function handleOAuthSlackAuthorizationResponse( return data; } -// server side handler to process the oauth redirect callback -// https://api.slack.com/authentication/oauth-v2#exchanging export async function handleOAuthGoogleDriveAuthorizationResponse( code: string, state: string -): Promise { +): Promise { const url = `/api/oauth/connector/google-drive/callback?code=${encodeURIComponent( code )}&state=${encodeURIComponent(state)}`; @@ -115,7 +119,7 @@ export async function handleOAuthGoogleDriveAuthorizationResponse( }); if (!response.ok) { - let errorDetails = `Failed to handle OAuth authorization response: ${response.status}`; + let errorDetails = `Failed to handle OAuth Google Drive authorization response: ${response.status}`; try { const responseBody = await response.text(); // Read the body as text @@ -132,6 +136,137 @@ export async function handleOAuthGoogleDriveAuthorizationResponse( } // Parse the JSON response - const data = (await response.json()) as OAuthGoogleDriveCallbackResponse; + const data = (await response.json()) as OAuthBaseCallbackResponse; + return data; +} + +// call server side helper +// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps +export async function handleOAuthConfluenceAuthorizationResponse( + code: string, + state: string +): Promise { + const url = `/api/oauth/connector/confluence/callback?code=${encodeURIComponent( + code + )}&state=${encodeURIComponent(state)}`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ code, state }), + }); + + if (!response.ok) { + let errorDetails = `Failed to handle OAuth Confluence authorization response: ${response.status}`; + + try { + const responseBody = await response.text(); // Read the body as text + errorDetails += `\nResponse Body: ${responseBody}`; + } catch (err) { + if (err instanceof Error) { + errorDetails += `\nUnable to read response body: ${err.message}`; + } else { + errorDetails += `\nUnable to read response body: Unknown error type`; + } + } + + throw new Error(errorDetails); + } + + // Parse the JSON response + const data = (await response.json()) as OAuthBaseCallbackResponse; + return data; +} + +export async function 
handleOAuthPrepareFinalization( + connector: string, + credential: number +) { + if (connector === "confluence") { + return handleOAuthConfluencePrepareFinalization(credential); + } + + return; +} + +// call server side helper +// https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps +export async function handleOAuthConfluencePrepareFinalization( + credential: number +): Promise { + const url = `/api/oauth/connector/confluence/accessible-resources?credential_id=${encodeURIComponent( + credential + )}`; + + const response = await fetch(url, { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + let errorDetails = `Failed to handle OAuth Confluence prepare finalization response: ${response.status}`; + + try { + const responseBody = await response.text(); // Read the body as text + errorDetails += `\nResponse Body: ${responseBody}`; + } catch (err) { + if (err instanceof Error) { + errorDetails += `\nUnable to read response body: ${err.message}`; + } else { + errorDetails += `\nUnable to read response body: Unknown error type`; + } + } + + throw new Error(errorDetails); + } + + // Parse the JSON response + const data = + (await response.json()) as OAuthConfluencePrepareFinalizationResponse; + return data; +} + +export async function handleOAuthConfluenceFinalize( + credential_id: number, + cloud_id: string, + cloud_name: string, + cloud_url: string +): Promise { + const url = `/api/oauth/connector/confluence/finalize?credential_id=${encodeURIComponent( + credential_id + )}&cloud_id=${encodeURIComponent(cloud_id)}&cloud_name=${encodeURIComponent( + cloud_name + )}&cloud_url=${encodeURIComponent(cloud_url)}`; + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + }); + + if (!response.ok) { + let errorDetails = `Failed to handle OAuth Confluence finalization response: ${response.status}`; + + try { + const responseBody = await response.text(); // Read the body as text + errorDetails += `\nResponse Body: ${responseBody}`; + } catch (err) { + if (err instanceof Error) { + errorDetails += `\nUnable to read response body: ${err.message}`; + } else { + errorDetails += `\nUnable to read response body: Unknown error type`; + } + } + + throw new Error(errorDetails); + } + + // Parse the JSON response + const data = (await response.json()) as OAuthConfluenceFinalizeResponse; return data; } diff --git a/web/src/lib/sources.ts b/web/src/lib/sources.ts index 846918986..7a3341256 100644 --- a/web/src/lib/sources.ts +++ b/web/src/lib/sources.ts @@ -120,6 +120,7 @@ export const SOURCE_METADATA_MAP: SourceMap = { displayName: "Confluence", category: SourceCategory.Wiki, docs: "https://docs.onyx.app/connectors/confluence", + oauthSupported: true, }, jira: { icon: JiraIcon, diff --git a/web/src/lib/types.ts b/web/src/lib/types.ts index 7288fe905..c667bc156 100644 --- a/web/src/lib/types.ts +++ b/web/src/lib/types.ts @@ -167,18 +167,36 @@ export interface OAuthPrepareAuthorizationResponse { url: string; } -export interface OAuthSlackCallbackResponse { +export interface OAuthBaseCallbackResponse { success: boolean; message: string; - team_id: string; - authed_user_id: string; + finalize_url: string | null; redirect_on_success: string; } -export interface OAuthGoogleDriveCallbackResponse { +export interface OAuthSlackCallbackResponse extends OAuthBaseCallbackResponse { + team_id: string; + authed_user_id: string; +} + +export interface ConfluenceAccessibleResource { + id: string; + name: 
string; + url: string; + scopes: string[]; + avatarUrl: string; +} + +export interface OAuthConfluencePrepareFinalizationResponse { success: boolean; message: string; - redirect_on_success: string; + accessible_resources: ConfluenceAccessibleResource[]; +} + +export interface OAuthConfluenceFinalizeResponse { + success: boolean; + message: string; + redirect_url: string; } export interface CCPairBasicInfo { @@ -255,6 +273,7 @@ export interface ChannelConfig { channel_name: string; respond_tag_only?: boolean; respond_to_bots?: boolean; + is_ephemeral?: boolean; show_continue_in_web_ui?: boolean; respond_member_group_list?: string[]; answer_filters?: AnswerFilterOption[]; @@ -382,6 +401,7 @@ export const oauthSupportedSources: ConfigurableSources[] = [ ValidSources.Slack, // NOTE: temporarily disabled until our GDrive App is approved // ValidSources.GoogleDrive, + ValidSources.Confluence, ]; export type OAuthSupportedSource = (typeof oauthSupportedSources)[number];
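Read end to end, the new Confluence helpers in oauth_utils.ts and the response types above describe a three-step flow: exchange the authorization code, discover the accessible Confluence sites, then bind the credential to one of them. A minimal sketch of how a redirect handler might chain them, using only the functions and types added in this diff; the wrapper function, its name, and the pick-the-first-resource shortcut are illustrative, not part of this PR:

```ts
// Sketch of chaining the Confluence OAuth helpers added in this diff.
// Error handling and the site-picker UI are elided.
import {
  handleOAuthConfluenceAuthorizationResponse,
  handleOAuthConfluencePrepareFinalization,
  handleOAuthConfluenceFinalize,
} from "@/lib/oauth_utils";

export async function completeConfluenceOAuth(
  code: string,
  state: string,
  credentialId: number
): Promise<string> {
  // 1) Exchange the authorization code server-side; a non-null finalize_url
  //    signals the credential still needs to be bound to a Confluence site.
  const callback = await handleOAuthConfluenceAuthorizationResponse(
    code,
    state
  );
  if (!callback.finalize_url) {
    return callback.redirect_on_success;
  }

  // 2) List the Confluence cloud instances this token can access.
  const prepared = await handleOAuthConfluencePrepareFinalization(
    credentialId
  );
  const resource = prepared.accessible_resources[0]; // real UI would let the user pick
  if (!resource) {
    throw new Error("No accessible Confluence resources for this credential");
  }

  // 3) Bind the credential to the chosen site and follow the redirect.
  const finalized = await handleOAuthConfluenceFinalize(
    credentialId,
    resource.id,
    resource.name,
    resource.url
  );
  return finalized.redirect_url;
}
```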