mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 20:39:29 +02:00
Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/schema-translate-map
# Conflicts: # backend/ee/onyx/server/oauth.py # backend/onyx/background/celery/tasks/indexing/tasks.py # backend/onyx/db/search_settings.py # backend/onyx/key_value_store/store.py # backend/onyx/onyxbot/slack/handlers/handle_buttons.py # backend/tests/integration/common_utils/reset.py
This commit is contained in:
commit
7acbadd825
1
.github/CODEOWNERS
vendored
Normal file
1
.github/CODEOWNERS
vendored
Normal file
@ -0,0 +1 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
94
.github/workflows/nightly-scan-licenses.yml
vendored
94
.github/workflows/nightly-scan-licenses.yml
vendored
@ -53,24 +53,90 @@ jobs:
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
if: always()
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@0.29.0
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
@ -0,0 +1,55 @@
|
||||
"""add background_reindex_enabled field
|
||||
|
||||
Revision ID: b7c2b63c4a03
|
||||
Revises: f11b408e39d3
|
||||
Create Date: 2024-03-26 12:34:56.789012
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7c2b63c4a03"
|
||||
down_revision = "f11b408e39d3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add background_reindex_enabled column with default value of True
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"background_reindex_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="true",
|
||||
),
|
||||
)
|
||||
|
||||
# Add embedding_precision column with default value of FLOAT
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"embedding_precision",
|
||||
sa.Enum(EmbeddingPrecision, native_enum=False),
|
||||
nullable=False,
|
||||
server_default=EmbeddingPrecision.FLOAT.name,
|
||||
),
|
||||
)
|
||||
|
||||
# Add reduced_dimension column with default value of None
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the background_reindex_enabled column
|
||||
op.drop_column("search_settings", "background_reindex_enabled")
|
||||
op.drop_column("search_settings", "embedding_precision")
|
||||
op.drop_column("search_settings", "reduced_dimension")
|
@ -0,0 +1,36 @@
|
||||
"""force lowercase all users
|
||||
|
||||
Revision ID: f11b408e39d3
|
||||
Revises: 3bd4c84fe72f
|
||||
Create Date: 2025-02-26 17:04:55.683500
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f11b408e39d3"
|
||||
down_revision = "3bd4c84fe72f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Convert all existing user emails to lowercase
|
||||
from alembic import op
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
"""
|
||||
)
|
||||
|
||||
# 2) Add a check constraint to ensure emails are always lowercase
|
||||
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the check constraint
|
||||
from alembic import op
|
||||
|
||||
op.drop_constraint("ensure_lowercase_email", "user", type_="check")
|
@ -0,0 +1,42 @@
|
||||
"""lowercase multi-tenant user auth
|
||||
|
||||
Revision ID: 34e3630c7f32
|
||||
Revises: a4f6ee863c47
|
||||
Create Date: 2025-02-26 15:03:01.211894
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "34e3630c7f32"
|
||||
down_revision = "a4f6ee863c47"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1) Convert all existing rows to lowercase
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE user_tenant_mapping
|
||||
SET email = LOWER(email)
|
||||
"""
|
||||
)
|
||||
# 2) Add a check constraint so that emails cannot be written in uppercase
|
||||
op.create_check_constraint(
|
||||
"ensure_lowercase_email",
|
||||
"user_tenant_mapping",
|
||||
"email = LOWER(email)",
|
||||
schema="public",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the check constraint
|
||||
op.drop_constraint(
|
||||
"ensure_lowercase_email",
|
||||
"user_tenant_mapping",
|
||||
schema="public",
|
||||
type_="check",
|
||||
)
|
@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
|
@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
|
||||
)
|
||||
|
||||
# if the user is a curator in any of their groups, set their role to CURATOR
|
||||
# otherwise, set their role to BASIC
|
||||
# otherwise, set their role to BASIC only if they were previously a CURATOR
|
||||
if curator_relationships:
|
||||
user.role = UserRole.CURATOR
|
||||
elif user.role == UserRole.CURATOR:
|
||||
@ -631,7 +631,16 @@ def update_user_group(
|
||||
removed_users = db_session.scalars(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# Filter out admin and global curator users before validating curator status
|
||||
users_to_validate = [
|
||||
user
|
||||
for user in removed_users
|
||||
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
|
||||
]
|
||||
|
||||
if users_to_validate:
|
||||
_validate_curator_status__no_commit(db_session, users_to_validate)
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@ -354,7 +359,11 @@ def confluence_doc_sync(
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), "confluence", cc_pair.credential_id
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
|
@ -1,9 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -61,13 +63,27 @@ def _build_group_member_email_map(
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
confluence_client = build_confluence_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
)
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
url = wiki_base.rstrip("/")
|
||||
|
||||
probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
confluence_client = OnyxConfluence(is_cloud, url, provider)
|
||||
confluence_client._probe_connection(**probe_kwargs)
|
||||
confluence_client._initialize_connection(**final_kwargs)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
|
@ -32,7 +32,8 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
@ -119,6 +119,7 @@ def _build_onyx_groups(
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
|
@ -123,7 +123,8 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[ExternalUserGroup],
|
||||
|
@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth import router as oauth_router
|
||||
from ee.onyx.server.oauth.api import router as oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@ -152,4 +152,8 @@ def get_application() -> FastAPI:
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
# for debugging discovered routes
|
||||
# for route in application.router.routes:
|
||||
# print(f"Path: {route.path}, Methods: {route.methods}")
|
||||
|
||||
return application
|
||||
|
@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
|
||||
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.utils.logger import OnyxLoggingAdapter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@ -216,7 +216,7 @@ def _handle_standard_answers(
|
||||
all_blocks = restate_question_blocks + answer_blocks
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
receiver_ids=receiver_ids,
|
||||
@ -231,6 +231,7 @@ def _handle_standard_answers(
|
||||
client=client,
|
||||
channel=message_info.channel_to_respond,
|
||||
thread_ts=slack_thread_id,
|
||||
receiver_ids=receiver_ids,
|
||||
)
|
||||
|
||||
return True
|
||||
|
91
backend/ee/onyx/server/oauth/api.py
Normal file
91
backend/ee/onyx/server/oauth/api.py
Normal file
@ -0,0 +1,91 @@
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_admin_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
|
||||
session: str | None = None
|
||||
if connector == DocumentSource.SLACK:
|
||||
if not DEV_MODE:
|
||||
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
|
||||
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.CONFLUENCE:
|
||||
if not DEV_MODE:
|
||||
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
session = ConfluenceCloudOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
if not DEV_MODE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
if not oauth_url:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"The document source type {connector} does not have OAuth implemented",
|
||||
)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"The document source type {connector} failed to generate an OAuth session.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
return JSONResponse(content={"url": oauth_url})
|
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
@ -0,0 +1,3 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/oauth")
|
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
@ -0,0 +1,361 @@
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
class AccessibleResources(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
url: str
|
||||
scopes: list[str]
|
||||
avatarUrl: str
|
||||
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
|
||||
|
||||
ACCESSIBLE_RESOURCE_URL = (
|
||||
"https://api.atlassian.com/oauth/token/accessible-resources"
|
||||
)
|
||||
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
# classic scope
|
||||
"read:confluence-space.summary%20"
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence%20"
|
||||
"search:confluence%20"
|
||||
# granular scope
|
||||
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
|
||||
"offline_access"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
|
||||
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def generate_finalize_url(cls, credential_id: int) -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
|
||||
|
||||
|
||||
@router.post("/connector/confluence/callback")
|
||||
def confluence_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Handles the backend logic for the frontend page that the user is redirected to
|
||||
after visiting the oauth authorization url."""
|
||||
|
||||
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Confluence Cloud client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = ConfluenceCloudOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
ConfluenceCloudOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
|
||||
|
||||
try:
|
||||
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
|
||||
response.text
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError(
|
||||
"Confluence Cloud OAuth failed during code/token exchange."
|
||||
)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=token_response.expires_in)
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={
|
||||
"confluence_access_token": token_response.access_token,
|
||||
"confluence_refresh_token": token_response.refresh_token,
|
||||
"created_at": now.isoformat(),
|
||||
"expires_at": expires_at.isoformat(),
|
||||
"expires_in": token_response.expires_in,
|
||||
"scope": token_response.scope,
|
||||
},
|
||||
admin_public=True,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
name="Confluence Cloud OAuth",
|
||||
)
|
||||
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud OAuth completed successfully.",
|
||||
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/connector/confluence/accessible-resources")
|
||||
def confluence_oauth_accessible_resources(
|
||||
credential_id: int,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Atlassian's API is weird and does not supply us with enough info to be in a
|
||||
usable state after authorizing. All API's require a cloud id. We have to list
|
||||
the accessible resources/sites and let the user choose which site to use."""
|
||||
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if not credential:
|
||||
raise HTTPException(400, f"Credential {credential_id} not found.")
|
||||
|
||||
credential_dict = credential.credential_json
|
||||
access_token = credential_dict["confluence_access_token"]
|
||||
|
||||
try:
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.get(
|
||||
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
accessible_resources_data = response.json()
|
||||
|
||||
# Validate the list of AccessibleResources
|
||||
try:
|
||||
accessible_resources = [
|
||||
ConfluenceCloudOAuth.AccessibleResources(**resource)
|
||||
for resource in accessible_resources_data
|
||||
]
|
||||
except ValidationError as e:
|
||||
raise RuntimeError(f"Failed to parse accessible resources: {e}")
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
|
||||
},
|
||||
)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud get accessible resources completed successfully.",
|
||||
"accessible_resources": [
|
||||
resource.model_dump() for resource in accessible_resources
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/connector/confluence/finalize")
|
||||
def confluence_oauth_finalize(
|
||||
credential_id: int,
|
||||
cloud_id: str,
|
||||
cloud_name: str,
|
||||
cloud_url: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Saves the info for the selected cloud site to the credential.
|
||||
This is the final step in the confluence oauth flow where after the traditional
|
||||
OAuth process, the user has to select a site to associate with the credentials.
|
||||
After this, the credential is usable."""
|
||||
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if not credential:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
|
||||
)
|
||||
|
||||
new_credential_json: dict[str, Any] = dict(credential.credential_json)
|
||||
new_credential_json["cloud_id"] = cloud_id
|
||||
new_credential_json["cloud_name"] = cloud_name
|
||||
new_credential_json["wiki_base"] = cloud_url
|
||||
|
||||
try:
|
||||
update_credential_json(credential_id, new_credential_json, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Confluence Cloud OAuth finalized successfully.",
|
||||
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
|
||||
}
|
||||
)
|
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
@ -0,0 +1,229 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"finalize_url": None,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
197
backend/ee/onyx/server/oauth/slack.py
Normal file
197
backend/ee/onyx/server/oauth/slack.py
Normal file
@ -0,0 +1,197 @@
|
||||
import base64
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class SlackOAuth:
|
||||
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
|
||||
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
|
||||
|
||||
# SCOPE is per https://docs.danswer.dev/connectors/slack
|
||||
BOT_SCOPE = (
|
||||
"channels:history,"
|
||||
"channels:read,"
|
||||
"groups:history,"
|
||||
"groups:read,"
|
||||
"channels:join,"
|
||||
"im:history,"
|
||||
"users:read,"
|
||||
"users:read.email,"
|
||||
"usergroups:read"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = SlackOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = SlackOAuth.parse_session(session_json)
|
||||
|
||||
if not DEV_MODE:
|
||||
redirect_uri = SlackOAuth.REDIRECT_URI
|
||||
else:
|
||||
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
SlackOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": SlackOAuth.CLIENT_ID,
|
||||
"client_secret": SlackOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed: {response_data.get('error')}",
|
||||
)
|
||||
|
||||
# Extract token and team information
|
||||
access_token: str = response_data.get("access_token")
|
||||
team_id: str = response_data.get("team", {}).get("id")
|
||||
authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={"slack_bot_token": access_token},
|
||||
admin_public=True,
|
||||
source=DocumentSource.SLACK,
|
||||
name="Slack OAuth",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Slack OAuth completed successfully.",
|
||||
"finalize_url": None,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
"team_id": team_id,
|
||||
"authed_user_id": authed_user_id,
|
||||
}
|
||||
)
|
@ -138,6 +138,7 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
|
@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
def fetch_billing_information(
|
||||
tenant_id: str,
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Check if the response indicates no subscription
|
||||
if (
|
||||
isinstance(response_data, dict)
|
||||
and "subscribed" in response_data
|
||||
and not response_data["subscribed"]
|
||||
):
|
||||
return SubscriptionStatusResponse(**response_data)
|
||||
|
||||
# Otherwise, parse as BillingInformation
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
|
@ -78,7 +78,7 @@ class CloudEmbedding:
|
||||
self._closed = False
|
||||
|
||||
async def _embed_openai(
|
||||
self, texts: list[str], model: str | None
|
||||
self, texts: list[str], model: str | None, reduced_dimension: int | None
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
@ -91,7 +91,11 @@ class CloudEmbedding:
|
||||
final_embeddings: list[Embedding] = []
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(input=text_batch, model=model)
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
@ -223,9 +227,10 @@ class CloudEmbedding:
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name)
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
@ -326,6 +331,7 @@ async def embed_text(
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
reduced_dimension: int | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
@ -369,6 +375,7 @@ async def embed_text(
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
reduced_dimension=reduced_dimension,
|
||||
)
|
||||
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
@ -508,6 +515,7 @@ async def process_embed_request(
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
reduced_dimension=embed_request.reduced_dimension,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
|
@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
user: User
|
||||
user: User | None = None
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = cast(User, await self.user_db.get_by_email(account_email))
|
||||
user = await self.user_db.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
# Make sure user is not None before adding OAuth account
|
||||
if user is not None:
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
else:
|
||||
# This shouldn't happen since get_by_email would raise UserNotExists
|
||||
# but adding as a safeguard
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
# Add OAuth account only if user creation was successful
|
||||
if user is not None:
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create user account"
|
||||
)
|
||||
|
||||
else:
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
# User exists, update OAuth account if needed
|
||||
if user is not None: # Add explicit check
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# Ensure user is not None before proceeding
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to authenticate or create user"
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(cc_pair)
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
|
@ -23,9 +23,9 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.indexing.utils import _should_index
|
||||
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import should_index
|
||||
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
|
||||
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
|
||||
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
|
||||
@ -61,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.session import get_session_with_current_tenant
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
@ -406,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
old_search_settings = check_and_perform_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
@ -439,6 +439,15 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
for search_settings_instance in search_settings_list:
|
||||
# skip non-live search settings that don't have background reindex enabled
|
||||
# those should just auto-change to live shortly after creation without
|
||||
# requiring any indexing till that point
|
||||
if (
|
||||
not search_settings_instance.status.is_current()
|
||||
and not search_settings_instance.background_reindex_enabled
|
||||
):
|
||||
continue
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
)
|
||||
@ -456,23 +465,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair.id, search_settings_instance.id, db_session
|
||||
)
|
||||
|
||||
search_settings_primary = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
search_settings_primary = True
|
||||
|
||||
if not _should_index(
|
||||
if not should_index(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
search_settings_instance=search_settings_instance,
|
||||
search_settings_primary=search_settings_primary,
|
||||
secondary_index_building=len(search_settings_list) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
reindex = False
|
||||
if search_settings_instance.id == search_settings_list[0].id:
|
||||
# the indexing trigger is only checked and cleared with the primary search settings
|
||||
if search_settings_instance.status.is_current():
|
||||
# the indexing trigger is only checked and cleared with the current search settings
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
|
||||
reindex = True
|
||||
|
@ -346,11 +346,10 @@ def validate_indexing_fences(
|
||||
return
|
||||
|
||||
|
||||
def _should_index(
|
||||
def should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
search_settings_instance: SearchSettings,
|
||||
search_settings_primary: bool,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
@ -415,9 +414,9 @@ def _should_index(
|
||||
):
|
||||
return False
|
||||
|
||||
if search_settings_primary:
|
||||
if search_settings_instance.status.is_current():
|
||||
if cc_pair.indexing_trigger is not None:
|
||||
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
|
||||
# if a manual indexing trigger is on the cc pair, honor it for live search settings
|
||||
return True
|
||||
|
||||
# if no attempt has ever occurred, we should index regardless of refresh_freq
|
||||
|
@ -11,10 +11,27 @@ def emit_background_error(
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
error_message = ""
|
||||
|
||||
# try to write to the db, but handle IntegrityError specifically
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = (
|
||||
f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not error_message:
|
||||
return
|
||||
|
||||
# if we get here from an IntegrityError, try to write the error message to the db
|
||||
# we need a new session because the first session is now invalid
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_background_error(db_session, error_message, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -93,10 +93,11 @@ def _get_connector_runner(
|
||||
runnable_connector.validate_connector_settings()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
logger.exception("Unable to instantiate connector.")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed. Sometimes there are cases where the connector will
|
||||
# it will never succeed
|
||||
|
||||
# Sometimes there are cases where the connector will
|
||||
# intermittently fail to initialize in which case we should pass in
|
||||
# leave_connector_active=True to allow it to continue.
|
||||
# For example, if there is nightly maintenance on a Confluence Server instance,
|
||||
|
@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
extract_text_from_confluence_html,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import attachment_to_content
|
||||
from onyx.connectors.confluence.utils import build_confluence_document_id
|
||||
from onyx.connectors.confluence.utils import datetime_from_string
|
||||
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class ConfluenceConnector(
|
||||
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
self.credentials_provider: CredentialsProviderInterface | None = None
|
||||
|
||||
self.probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
self.final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self._confluence_client = build_confluence_client(
|
||||
credentials=credentials,
|
||||
is_cloud=self.is_cloud,
|
||||
wiki_base=self.wiki_base,
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
# raises exception if there's a problem
|
||||
confluence_client = OnyxConfluence(
|
||||
self.is_cloud, self.wiki_base, credentials_provider
|
||||
)
|
||||
return None
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
|
||||
self._confluence_client = confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return comment_string
|
||||
|
||||
def _convert_object_to_document(
|
||||
self, confluence_object: dict[str, Any]
|
||||
self,
|
||||
confluence_object: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Takes in a confluence object, extracts all metadata, and converts it into a document.
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
|
||||
parent_content_id: if the object is an attachment, specifies the content id that
|
||||
the attachment is attached to
|
||||
"""
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=confluence_object,
|
||||
parent_content_id=parent_content_id,
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
doc = self._convert_object_to_document(attachment, confluence_page_id)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
|
@ -1,19 +1,37 @@
|
||||
import math
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -22,12 +40,14 @@ logger = setup_logger()
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
class OnyxConfluence:
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is a custom Confluence class that:
|
||||
|
||||
A. overrides the default Confluence class to add a custom CQL method.
|
||||
B.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
All methods are automatically wrapped with handle_confluence_rate_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
CREDENTIAL_PREFIX = "connector:confluence:credential"
|
||||
CREDENTIAL_TTL = 300 # 5 min
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
is_cloud: bool,
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
) -> None:
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
self.redis_client = get_redis_client(
|
||||
tenant_id=credentials_provider.get_tenant_id()
|
||||
)
|
||||
else:
|
||||
self.static_credentials = self._credentials_provider.get_credentials()
|
||||
|
||||
self._confluence = Confluence(url)
|
||||
self.credential_key: str = (
|
||||
self.CREDENTIAL_PREFIX
|
||||
+ f":credential_{self._credentials_provider.get_provider_key()}"
|
||||
)
|
||||
|
||||
self._kwargs: Any = None
|
||||
|
||||
self.shared_base_kwargs = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
|
||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
||||
"""credential_json - the current json credentials
|
||||
Returns a tuple
|
||||
1. The up to date credentials
|
||||
2. True if the credentials were updated
|
||||
|
||||
This method is intended to be used within a distributed lock.
|
||||
Lock, call this, update credentials if the tokens were refreshed, then release
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
wrap it with handle_confluence_rate_limit.
|
||||
"""
|
||||
for attr_name in dir(self):
|
||||
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
setattr(
|
||||
self,
|
||||
attr_name,
|
||||
handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# static credentials are preloaded, so no locking/redis required
|
||||
if self.static_credentials:
|
||||
return self.static_credentials, False
|
||||
|
||||
if not self.redis_client:
|
||||
raise RuntimeError("self.redis_client is None")
|
||||
|
||||
# dynamic credentials need locking
|
||||
# check redis first, then fallback to the DB
|
||||
credential_raw = self.redis_client.get(self.credential_key)
|
||||
if credential_raw is not None:
|
||||
credential_bytes = cast(bytes, credential_raw)
|
||||
credential_str = credential_bytes.decode("utf-8")
|
||||
credential_json: dict[str, Any] = json.loads(credential_str)
|
||||
else:
|
||||
credential_json = self._credentials_provider.get_credentials()
|
||||
|
||||
if "confluence_refresh_token" not in credential_json:
|
||||
# static credentials ... cache them permanently and return
|
||||
self.static_credentials = credential_json
|
||||
return credential_json, False
|
||||
|
||||
# check if we should refresh tokens. we're deciding to refresh halfway
|
||||
# to expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
created_at = datetime.fromisoformat(credential_json["created_at"])
|
||||
expires_in: int = credential_json["expires_in"]
|
||||
renew_at = created_at + timedelta(seconds=expires_in // 2)
|
||||
if now <= renew_at:
|
||||
# cached/current credentials are reasonably up to date
|
||||
return credential_json, False
|
||||
|
||||
# we need to refresh
|
||||
logger.info("Renewing Confluence Cloud credentials...")
|
||||
new_credentials = confluence_refresh_tokens(
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
|
||||
credential_json["cloud_id"],
|
||||
credential_json["confluence_refresh_token"],
|
||||
)
|
||||
|
||||
# store the new credentials to redis and to the db thru the provider
|
||||
# redis: we use a 5 min TTL because we are given a 10 minute grace period
|
||||
# when keys are rotated. it's easier to expire the cached credentials
|
||||
# reasonably frequently rather than trying to handle strong synchronization
|
||||
# between the db and redis everywhere the credentials might be updated
|
||||
new_credential_str = json.dumps(new_credentials)
|
||||
self.redis_client.set(
|
||||
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
|
||||
)
|
||||
self._credentials_provider.set_credentials(new_credentials)
|
||||
|
||||
return new_credentials, True
|
||||
|
||||
@staticmethod
|
||||
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
oauth2_dict: dict[str, Any] = {}
|
||||
if "confluence_refresh_token" in credentials:
|
||||
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
oauth2_dict["token"] = {}
|
||||
oauth2_dict["token"]["access_token"] = credentials[
|
||||
"confluence_access_token"
|
||||
]
|
||||
return oauth2_dict
|
||||
|
||||
def _probe_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
logger.info("Probing Confluence with OAuth Access Token.")
|
||||
|
||||
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
|
||||
credentials
|
||||
)
|
||||
url = (
|
||||
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
)
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url, oauth2=oauth2_dict, **merged_kwargs
|
||||
)
|
||||
else:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
password=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
else:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
token=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {url}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
|
||||
logger.info("Confluence probe succeeded.")
|
||||
|
||||
def _initialize_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Called externally to init the connection in a thread safe manner."""
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **merged_kwargs
|
||||
)
|
||||
self._kwargs = merged_kwargs
|
||||
|
||||
def _initialize_connection_helper(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> Confluence:
|
||||
"""Called internally to init the connection. Distributed locking
|
||||
to prevent multiple threads from modifying the credentials
|
||||
must be handled around this function."""
|
||||
|
||||
confluence = None
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
logger.info("Connecting to Confluence Cloud with OAuth Access Token.")
|
||||
|
||||
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
|
||||
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
|
||||
else:
|
||||
logger.info("Connecting to Confluence with Personal Access Token.")
|
||||
if self._is_cloud:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
username=credentials["confluence_username"],
|
||||
password=credentials["confluence_access_token"],
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
confluence = Confluence(
|
||||
url=self._url,
|
||||
token=credentials["confluence_access_token"],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return confluence
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def _make_rate_limited_confluence_method(
|
||||
self, name: str, credential_provider: CredentialsProviderInterface | None
|
||||
) -> Callable[..., Any]:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
try:
|
||||
if credential_provider:
|
||||
with credential_provider:
|
||||
credentials, renewed = self._renew_credentials()
|
||||
if renewed:
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **self._kwargs
|
||||
)
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
return attr(*args, **kwargs)
|
||||
else:
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return wrapped_call
|
||||
|
||||
# def _wrap_methods(self) -> None:
|
||||
# """
|
||||
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
# wrap it with handle_confluence_rate_limit.
|
||||
# """
|
||||
# for attr_name in dir(self):
|
||||
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
# setattr(
|
||||
# self,
|
||||
# attr_name,
|
||||
# handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# )
|
||||
|
||||
# def _ensure_token_valid(self) -> None:
|
||||
# if self._token_is_expired():
|
||||
# self._refresh_token()
|
||||
# # Re-init the Confluence client with the originally stored args
|
||||
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Dynamically intercept attribute/method access."""
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
# If it's not a method, just return it after ensuring token validity
|
||||
if not callable(attr):
|
||||
return attr
|
||||
|
||||
# skip methods that start with "_"
|
||||
if name.startswith("_"):
|
||||
return attr
|
||||
|
||||
# wrap the method with our retry handler
|
||||
rate_limited_method: Callable[
|
||||
..., Any
|
||||
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
|
||||
|
||||
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
|
||||
return rate_limited_method(*args, **kwargs)
|
||||
|
||||
return wrapped_method
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
@ -507,63 +752,212 @@ class OnyxConfluence(Confluence):
|
||||
return response
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> None:
|
||||
# test connection with direct client, no retries
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=6,
|
||||
max_backoff_seconds=10,
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
global _USER_ID_TO_DISPLAY_NAME_CACHE
|
||||
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_userkey(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
if not found_display_name:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_accountid(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
|
||||
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
|
||||
# why are we using session.get here? we probably won't retry these ... is that ok?
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
return extracted_text
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
def build_confluence_client(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> OnyxConfluence:
|
||||
try:
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(str(e))
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
cloud=is_cloud,
|
||||
)
|
||||
return format_document_soup(soup)
|
||||
|
@ -1,185 +1,38 @@
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
pass
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: "OnyxConfluence", user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
user_id (str): The user id (i.e: the account-id or userkey)
|
||||
confluence_client (Confluence): The Confluence Client
|
||||
|
||||
Returns:
|
||||
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
|
||||
"""
|
||||
global _USER_ID_TO_DISPLAY_NAME_CACHE
|
||||
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_userkey(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
if not found_display_name:
|
||||
try:
|
||||
result = confluence_client.get_user_details_by_accountid(user_id)
|
||||
found_display_name = result.get("displayName")
|
||||
except Exception:
|
||||
found_display_name = None
|
||||
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
|
||||
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formated Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return format_document_soup(soup)
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
scope: str
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
]
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
@ -284,6 +94,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
return datetime_object
|
||||
|
||||
|
||||
def confluence_refresh_tokens(
|
||||
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
|
||||
) -> dict[str, Any]:
|
||||
# rotate the refresh and access token
|
||||
# Note that access tokens are only good for an hour in confluence cloud,
|
||||
# so we're going to have problems if the connector runs for longer
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
|
||||
response = requests.post(
|
||||
CONFLUENCE_OAUTH_TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
token_response = TokenResponse.model_validate_json(response.text)
|
||||
except Exception:
|
||||
raise RuntimeError("Confluence Cloud token refresh failed.")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
expires_at = now + timedelta(seconds=token_response.expires_in)
|
||||
|
||||
new_credentials: dict[str, Any] = {}
|
||||
new_credentials["confluence_access_token"] = token_response.access_token
|
||||
new_credentials["confluence_refresh_token"] = token_response.refresh_token
|
||||
new_credentials["created_at"] = now.isoformat()
|
||||
new_credentials["expires_at"] = expires_at.isoformat()
|
||||
new_credentials["expires_in"] = token_response.expires_in
|
||||
new_credentials["scope"] = token_response.scope
|
||||
new_credentials["cloud_id"] = cloud_id
|
||||
return new_credentials
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
try:
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
BACKOFF = 2
|
||||
|
||||
# Check if the response or headers are None to avoid potential AttributeError
|
||||
if e.response is None or e.response.headers is None:
|
||||
logger.warning("HTTPError with `None` as response or as headers")
|
||||
raise e
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
):
|
||||
raise e
|
||||
|
||||
retry_after = None
|
||||
|
||||
retry_after_header = e.response.headers.get("Retry-After")
|
||||
if retry_after_header is not None:
|
||||
try:
|
||||
retry_after = int(retry_after_header)
|
||||
if retry_after > MAX_DELAY:
|
||||
logger.warning(
|
||||
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
|
||||
)
|
||||
retry_after = MAX_DELAY
|
||||
if retry_after < MIN_DELAY:
|
||||
retry_after = MIN_DELAY
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if retry_after is not None:
|
||||
logger.warning(
|
||||
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
|
||||
)
|
||||
delay = retry_after
|
||||
else:
|
||||
logger.warning(
|
||||
"Rate limiting without retry header. Retrying with exponential backoff..."
|
||||
)
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
|
||||
delay_until = math.ceil(time.monotonic() + delay)
|
||||
return delay_until
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
|
135
backend/onyx/connectors/credentials_provider.py
Normal file
135
backend/onyx/connectors/credentials_provider.py
Normal file
@ -0,0 +1,135 @@
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import Credential
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class OnyxDBCredentialsProvider(
|
||||
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
|
||||
):
|
||||
"""Implementation to allow the connector to callback and update credentials in the db.
|
||||
Required in cases where credentials can rotate while the connector is running.
|
||||
"""
|
||||
|
||||
LOCK_TTL = 900 # TTL of the lock
|
||||
|
||||
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
|
||||
self._tenant_id = tenant_id
|
||||
self._connector_name = connector_name
|
||||
self._credential_id = credential_id
|
||||
|
||||
self.redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# lock used to prevent overlapping renewal of credentials
|
||||
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
|
||||
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
|
||||
|
||||
def __enter__(self) -> "OnyxDBCredentialsProvider":
|
||||
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
|
||||
if not acquired:
|
||||
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
"""Release the lock when exiting the context."""
|
||||
if self._lock and self._lock.owned():
|
||||
self._lock.release()
|
||||
|
||||
def get_tenant_id(self) -> str | None:
|
||||
return self._tenant_id
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return str(self._credential_id)
|
||||
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
credential = db_session.execute(
|
||||
select(Credential).where(Credential.id == self._credential_id)
|
||||
).scalar_one()
|
||||
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
return credential.credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
try:
|
||||
credential = db_session.execute(
|
||||
select(Credential)
|
||||
.where(Credential.id == self._credential_id)
|
||||
.with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
credential.credential_json = credential_json
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class OnyxStaticCredentialsProvider(
|
||||
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
|
||||
):
|
||||
"""Implementation (a very simple one!) to handle static credentials."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
connector_name: str,
|
||||
credential_json: dict[str, Any],
|
||||
):
|
||||
self._tenant_id = tenant_id
|
||||
self._connector_name = connector_name
|
||||
self._credential_json = credential_json
|
||||
|
||||
self._provider_key = str(uuid.uuid4())
|
||||
|
||||
def __enter__(self) -> "OnyxStaticCredentialsProvider":
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def get_tenant_id(self) -> str | None:
|
||||
return self._tenant_id
|
||||
|
||||
def get_provider_key(self) -> str:
|
||||
return self._provider_key
|
||||
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
return self._credential_json
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
self._credential_json = credential_json
|
||||
|
||||
def is_dynamic(self) -> bool:
|
||||
return False
|
@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.models import Credential
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
class ConnectorMissingException(Exception):
|
||||
@ -167,10 +170,17 @@ def instantiate_connector(
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
if isinstance(connector, CredentialsConnector):
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), str(source), credential.id
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
|
@ -1,7 +1,10 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||
|
||||
|
||||
class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
@abc.abstractmethod
|
||||
def __enter__(self) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_value: BaseException | None,
|
||||
traceback: TracebackType | None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_tenant_id(self) -> str | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_provider_key(self) -> str:
|
||||
"""a unique key that the connector can use to lock around a credential
|
||||
that might be used simultaneously.
|
||||
|
||||
Will typically be the credential id, but can also just be something random
|
||||
in cases when there is nothing to lock (aka static credentials)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_credentials(self) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
If static, the client can simply reference the credentials once and use them
|
||||
through the entire indexing run.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CredentialsConnector(BaseConnector):
|
||||
"""Implement this if the connector needs to be able to read and write credentials
|
||||
on the fly. Typically used with shared credentials/tokens that might be renewed
|
||||
at any time."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Event driven
|
||||
class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
|
@ -72,6 +72,7 @@ def make_slack_api_rate_limited(
|
||||
@wraps(call)
|
||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
last_exception = None
|
||||
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
|
@ -42,6 +42,10 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
# Threshold for determining when to replace vs append iframe content
|
||||
IFRAME_TEXT_LENGTH_THRESHOLD = 700
|
||||
# Message indicating JavaScript is disabled, which often appears when scraping fails
|
||||
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
|
||||
|
||||
|
||||
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
|
||||
@ -138,7 +142,8 @@ def get_internal_links(
|
||||
# Account for malformed backslashes in URLs
|
||||
href = href.replace("\\", "/")
|
||||
|
||||
if should_ignore_pound and "#" in href:
|
||||
# "#!" indicates the page is using a hashbang URL, which is a client-side routing technique
|
||||
if should_ignore_pound and "#" in href and "#!" not in href:
|
||||
href = href.split("#")[0]
|
||||
|
||||
if not is_valid_url(href):
|
||||
@ -288,6 +293,7 @@ class WebConnector(LoadConnector):
|
||||
and converts them into documents"""
|
||||
visited_links: set[str] = set()
|
||||
to_visit: list[str] = self.to_visit_list
|
||||
content_hashes = set()
|
||||
|
||||
if not to_visit:
|
||||
raise ValueError("No URLs to visit")
|
||||
@ -302,29 +308,30 @@ class WebConnector(LoadConnector):
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
while to_visit:
|
||||
current_url = to_visit.pop()
|
||||
if current_url in visited_links:
|
||||
initial_url = to_visit.pop()
|
||||
if initial_url in visited_links:
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
visited_links.add(initial_url)
|
||||
|
||||
try:
|
||||
protected_url_check(current_url)
|
||||
protected_url_check(initial_url)
|
||||
except Exception as e:
|
||||
last_error = f"Invalid URL {current_url} due to {e}"
|
||||
last_error = f"Invalid URL {initial_url} due to {e}"
|
||||
logger.warning(last_error)
|
||||
continue
|
||||
|
||||
logger.info(f"Visiting {current_url}")
|
||||
index = len(visited_links)
|
||||
logger.info(f"{index}: Visiting {initial_url}")
|
||||
|
||||
try:
|
||||
check_internet_connection(current_url)
|
||||
check_internet_connection(initial_url)
|
||||
if restart_playwright:
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
|
||||
if current_url.split(".")[-1] == "pdf":
|
||||
if initial_url.split(".")[-1] == "pdf":
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(current_url)
|
||||
response = requests.get(initial_url)
|
||||
page_text, metadata = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
@ -332,10 +339,10 @@ class WebConnector(LoadConnector):
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
sections=[Section(link=current_url, text=page_text)],
|
||||
id=initial_url,
|
||||
sections=[Section(link=initial_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
semantic_identifier=initial_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@ -347,21 +354,29 @@ class WebConnector(LoadConnector):
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
|
||||
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
|
||||
page_response = page.goto(
|
||||
initial_url,
|
||||
timeout=30000, # 30 seconds
|
||||
)
|
||||
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
protected_url_check(final_page)
|
||||
current_url = final_page
|
||||
if current_url in visited_links:
|
||||
logger.info("Redirected page already indexed")
|
||||
final_url = page.url
|
||||
if final_url != initial_url:
|
||||
protected_url_check(final_url)
|
||||
initial_url = final_url
|
||||
if initial_url in visited_links:
|
||||
logger.info(
|
||||
f"{index}: {initial_url} redirected to {final_url} - already indexed"
|
||||
)
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
logger.info(f"{index}: {initial_url} redirected to {final_url}")
|
||||
visited_links.add(initial_url)
|
||||
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
@ -379,26 +394,58 @@ class WebConnector(LoadConnector):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
if self.recursive:
|
||||
internal_links = get_internal_links(base_url, current_url, soup)
|
||||
internal_links = get_internal_links(base_url, initial_url, soup)
|
||||
for link in internal_links:
|
||||
if link not in visited_links:
|
||||
to_visit.append(link)
|
||||
|
||||
if page_response and str(page_response.status)[0] in ("4", "5"):
|
||||
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
|
||||
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
|
||||
logger.info(last_error)
|
||||
continue
|
||||
|
||||
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
|
||||
|
||||
"""For websites containing iframes that need to be scraped,
|
||||
the code below can extract text from within these iframes.
|
||||
"""
|
||||
logger.debug(
|
||||
f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}"
|
||||
)
|
||||
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
|
||||
iframe_count = page.frame_locator("iframe").locator("html").count()
|
||||
if iframe_count > 0:
|
||||
iframe_texts = (
|
||||
page.frame_locator("iframe")
|
||||
.locator("html")
|
||||
.all_inner_texts()
|
||||
)
|
||||
document_text = "\n".join(iframe_texts)
|
||||
""" 700 is the threshold value for the length of the text extracted
|
||||
from the iframe based on the issue faced """
|
||||
if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD:
|
||||
parsed_html.cleaned_text = document_text
|
||||
else:
|
||||
parsed_html.cleaned_text += "\n" + document_text
|
||||
|
||||
# Sometimes pages with #! will serve duplicate content
|
||||
# There are also just other ways this can happen
|
||||
hashed_text = hash((parsed_html.title, parsed_html.cleaned_text))
|
||||
if hashed_text in content_hashes:
|
||||
logger.info(
|
||||
f"{index}: Skipping duplicate title + content for {initial_url}"
|
||||
)
|
||||
continue
|
||||
content_hashes.add(hashed_text)
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
id=initial_url,
|
||||
sections=[
|
||||
Section(link=current_url, text=parsed_html.cleaned_text)
|
||||
Section(link=initial_url, text=parsed_html.cleaned_text)
|
||||
],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
semantic_identifier=parsed_html.title or initial_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@ -410,7 +457,7 @@ class WebConnector(LoadConnector):
|
||||
|
||||
page.close()
|
||||
except Exception as e:
|
||||
last_error = f"Failed to fetch '{current_url}': {e}"
|
||||
last_error = f"Failed to fetch '{initial_url}': {e}"
|
||||
logger.exception(last_error)
|
||||
playwright.stop()
|
||||
restart_playwright = True
|
||||
|
@ -76,6 +76,10 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
provider_type=search_settings.provider_type,
|
||||
index_name=search_settings.index_name,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
# Whether switching to this model requires re-indexing
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
# Reranking Details
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
|
@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
stmt = stmt.order_by(desc(ChatSession.time_updated))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
@ -962,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
is_agentic=chat_message.is_agentic,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
|
@ -360,18 +360,13 @@ def backend_update_credential_json(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential(
|
||||
def _delete_credential_internal(
|
||||
credential: Credential,
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
"""Internal utility function to handle the actual deletion of a credential"""
|
||||
associated_connectors = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
||||
@ -416,6 +411,35 @@ def delete_credential(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential_for_user(
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Delete a credential that belongs to a specific user"""
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
_delete_credential_internal(credential, credential_id, db_session, force)
|
||||
|
||||
|
||||
def delete_credential(
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Delete a credential regardless of ownership (admin function)"""
|
||||
credential = fetch_credential_by_id(credential_id, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(f"Credential by provided id {credential_id} does not exist")
|
||||
|
||||
_delete_credential_internal(credential, credential_id, db_session, force)
|
||||
|
||||
|
||||
def create_initial_public_credential(db_session: Session) -> None:
|
||||
error_msg = (
|
||||
"DB is not in a valid initial state."
|
||||
|
@ -63,6 +63,9 @@ class IndexModelStatus(str, PyEnum):
|
||||
PRESENT = "PRESENT"
|
||||
FUTURE = "FUTURE"
|
||||
|
||||
def is_current(self) -> bool:
|
||||
return self == IndexModelStatus.PRESENT
|
||||
|
||||
|
||||
class ChatSessionSharedStatus(str, PyEnum):
|
||||
PUBLIC = "public"
|
||||
@ -83,3 +86,11 @@ class AccessType(str, PyEnum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
SYNC = "sync"
|
||||
|
||||
|
||||
class EmbeddingPrecision(str, PyEnum):
|
||||
# matches vespa tensor type
|
||||
# only support float / bfloat16 for now, since there's not a
|
||||
# good reason to specify anything else
|
||||
BFLOAT16 = "bfloat16"
|
||||
FLOAT = "float"
|
||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import validates
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@ -45,7 +46,13 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
|
||||
from onyx.db.enums import (
|
||||
AccessType,
|
||||
EmbeddingPrecision,
|
||||
IndexingMode,
|
||||
SyncType,
|
||||
SyncStatus,
|
||||
)
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.constants import TokenRateLimitScope
|
||||
@ -206,6 +213,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
@property
|
||||
def password_configured(self) -> bool:
|
||||
"""
|
||||
@ -711,6 +722,23 @@ class SearchSettings(Base):
|
||||
ForeignKey("embedding_provider.provider_type"), nullable=True
|
||||
)
|
||||
|
||||
# Whether switching to this model should re-index all connectors in the background
|
||||
# if no re-index is needed, will be ignored. Only used during the switch-over process.
|
||||
background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# allows for quantization -> less memory usage for a small performance hit
|
||||
embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
|
||||
Enum(EmbeddingPrecision, native_enum=False)
|
||||
)
|
||||
|
||||
# can be used to reduce dimensionality of vectors and save memory with
|
||||
# a small performance hit. More details in the `Reducing embedding dimensions`
|
||||
# section here:
|
||||
# https://platform.openai.com/docs/guides/embeddings#embedding-models
|
||||
# If not specified, will just use the model_dim without any reduction.
|
||||
# NOTE: this is only currently available for OpenAI models
|
||||
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Mini and Large Chunks (large chunk also checks for model max context)
|
||||
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
@ -792,6 +820,12 @@ class SearchSettings(Base):
|
||||
self.multipass_indexing, self.model_name, self.provider_type
|
||||
)
|
||||
|
||||
@property
|
||||
def final_embedding_dim(self) -> int:
|
||||
if self.reduced_dimension:
|
||||
return self.reduced_dimension
|
||||
return self.model_dim
|
||||
|
||||
@staticmethod
|
||||
def can_use_large_chunks(
|
||||
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
|
||||
@ -1756,6 +1790,7 @@ class ChannelConfig(TypedDict):
|
||||
channel_name: str | None # None for default channel config
|
||||
respond_tag_only: NotRequired[bool] # defaults to False
|
||||
respond_to_bots: NotRequired[bool] # defaults to False
|
||||
is_ephemeral: NotRequired[bool] # defaults to False
|
||||
respond_member_group_list: NotRequired[list[str]]
|
||||
answer_filters: NotRequired[list[AllowedAnswerFilters]]
|
||||
# If None then no follow up
|
||||
@ -2270,6 +2305,10 @@ class UserTenantMapping(Base):
|
||||
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
|
@ -209,13 +209,21 @@ def create_update_persona(
|
||||
if not all_prompt_ids:
|
||||
raise ValueError("No prompt IDs provided")
|
||||
|
||||
is_default_persona: bool | None = create_persona_request.is_default_persona
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
if user:
|
||||
# Curators can edit default personas, but not make them
|
||||
if (
|
||||
user.role == UserRole.CURATOR
|
||||
or user.role == UserRole.GLOBAL_CURATOR
|
||||
):
|
||||
is_default_persona = None
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
@ -241,7 +249,7 @@ def create_update_persona(
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
is_default_persona=is_default_persona,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
@ -428,7 +436,7 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool = False,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
@ -523,7 +531,11 @@ def upsert_persona(
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = is_default_persona
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
)
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
if document_sets is not None:
|
||||
@ -575,7 +587,9 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=is_default_persona,
|
||||
is_default_persona=is_default_persona
|
||||
if is_default_persona is not None
|
||||
else False,
|
||||
labels=labels or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
|
@ -13,6 +13,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.llm import fetch_embedding_provider
|
||||
from onyx.db.models import CloudEmbeddingProvider
|
||||
from onyx.db.models import IndexAttempt
|
||||
@ -59,12 +60,15 @@ def create_search_settings(
|
||||
index_name=search_settings.index_name,
|
||||
provider_type=search_settings.provider_type,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
rerank_api_key=search_settings.rerank_api_key,
|
||||
num_rerank=search_settings.num_rerank,
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
)
|
||||
|
||||
db_session.add(embedding_model)
|
||||
@ -305,6 +309,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
||||
model_dim=(
|
||||
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
),
|
||||
embedding_precision=(EmbeddingPrecision.FLOAT),
|
||||
normalize=(
|
||||
NORMALIZE_EMBEDDINGS
|
||||
if is_overridden
|
||||
@ -322,6 +327,7 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
||||
return IndexingSetting(
|
||||
model_name=DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=DOC_EMBEDDING_DIM,
|
||||
embedding_precision=(EmbeddingPrecision.FLOAT),
|
||||
normalize=NORMALIZE_EMBEDDINGS,
|
||||
query_prefix=ASYM_QUERY_PREFIX,
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
|
@ -8,10 +8,12 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import (
|
||||
count_unique_cc_pairs_with_successful_index_attempts,
|
||||
)
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@ -19,7 +21,49 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
def _perform_index_swap(
|
||||
db_session: Session,
|
||||
current_search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings,
|
||||
all_cc_pairs: list[ConnectorCredentialPair],
|
||||
) -> None:
|
||||
"""Swap the indices and expire the old one."""
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=current_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PRESENT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if len(all_cc_pairs) > 0:
|
||||
kv_store = get_kv_store()
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
# Expire jobs for the now past index/embedding model
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
# Recount aggregates
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
# remove the old index from the vector db
|
||||
document_index = get_default_document_index(secondary_search_settings, None)
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=secondary_search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=secondary_search_settings.embedding_precision,
|
||||
# just finished swap, no more secondary index
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
|
||||
def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
"""Get count of cc-pairs and count of successful index_attempts for the
|
||||
new model grouped by connector + credential, if it's the same, then assume
|
||||
new index is done building. If so, swap the indices and expire the old one.
|
||||
@ -27,52 +71,45 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
|
||||
old_search_settings = None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
search_settings = get_secondary_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
if not search_settings:
|
||||
if not secondary_search_settings:
|
||||
return None
|
||||
|
||||
# If the secondary search settings are not configured to reindex in the background,
|
||||
# we can just swap over instantly
|
||||
if not secondary_search_settings.background_reindex_enabled:
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
_perform_index_swap(
|
||||
db_session=db_session,
|
||||
current_search_settings=current_search_settings,
|
||||
secondary_search_settings=secondary_search_settings,
|
||||
all_cc_pairs=all_cc_pairs,
|
||||
)
|
||||
return current_search_settings
|
||||
|
||||
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
|
||||
# function is correct. The unique_cc_indexings are specifically for the existing cc-pairs
|
||||
old_search_settings = None
|
||||
if unique_cc_indexings > cc_pair_count:
|
||||
logger.error("More unique indexings than cc pairs, should not occur")
|
||||
|
||||
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
|
||||
# Swap indices
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
update_search_settings_status(
|
||||
search_settings=current_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
_perform_index_swap(
|
||||
db_session=db_session,
|
||||
current_search_settings=current_search_settings,
|
||||
secondary_search_settings=secondary_search_settings,
|
||||
all_cc_pairs=all_cc_pairs,
|
||||
)
|
||||
|
||||
update_search_settings_status(
|
||||
search_settings=search_settings,
|
||||
new_status=IndexModelStatus.PRESENT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if cc_pair_count > 0:
|
||||
kv_store = get_kv_store()
|
||||
kv_store.store(KV_REINDEX_KEY, False)
|
||||
|
||||
# Expire jobs for the now past index/embedding model
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
# Recount aggregates
|
||||
for cc_pair in all_cc_pairs:
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
old_search_settings = current_search_settings
|
||||
old_search_settings = current_search_settings
|
||||
|
||||
return old_search_settings
|
||||
|
@ -6,6 +6,7 @@ from typing import Any
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
@ -145,17 +146,21 @@ class Verifiable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
index_embedding_dim: int,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the document index exists and is consistent with the expectations in the code.
|
||||
|
||||
Parameters:
|
||||
- index_embedding_dim: Vector dimensionality for the vector similarity part of the search
|
||||
- primary_embedding_dim: Vector dimensionality for the vector similarity part of the search
|
||||
- primary_embedding_precision: Precision of the vector similarity part of the search
|
||||
- secondary_index_embedding_dim: Vector dimensionality of the secondary index being built
|
||||
behind the scenes. The secondary index should only be built when switching
|
||||
embedding models therefore this dim should be different from the primary index.
|
||||
- secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -164,6 +169,7 @@ class Verifiable(abc.ABC):
|
||||
def register_multitenant_indices(
|
||||
indices: list[str],
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
"""
|
||||
Register multitenant indices with the document index.
|
||||
|
@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
summary: dynamic
|
||||
}
|
||||
# Title embedding (x1)
|
||||
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
|
||||
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
|
||||
indexing: attribute | index
|
||||
attribute {
|
||||
distance-metric: angular
|
||||
@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME {
|
||||
}
|
||||
# Content embeddings (chunk + optional mini chunks embeddings)
|
||||
# "t" and "x" are arbitrary names, not special keywords
|
||||
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
|
||||
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
|
||||
indexing: attribute | index
|
||||
attribute {
|
||||
distance-metric: angular
|
||||
|
@ -5,4 +5,7 @@
|
||||
<allow
|
||||
until="DATE_REPLACEMENT"
|
||||
comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
|
||||
<allow
|
||||
until='DATE_REPLACEMENT'
|
||||
comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
|
||||
</validation-overrides>
|
||||
|
@ -310,6 +310,11 @@ def query_vespa(
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
+ (
|
||||
f"\nResponse: {e.response.text}"
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
else ""
|
||||
)
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
|
||||
|
@ -26,6 +26,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
@ -63,6 +64,7 @@ from onyx.document_index.vespa_constants import DATE_REPLACEMENT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
|
||||
@ -112,6 +114,21 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
|
||||
return "\n".join(doc_lines)
|
||||
|
||||
|
||||
def _replace_template_values_in_schema(
|
||||
schema_template: str,
|
||||
index_name: str,
|
||||
embedding_dim: int,
|
||||
embedding_precision: EmbeddingPrecision,
|
||||
) -> str:
|
||||
return (
|
||||
schema_template.replace(
|
||||
EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value
|
||||
)
|
||||
.replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name)
|
||||
.replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
|
||||
)
|
||||
|
||||
|
||||
def add_ngrams_to_schema(schema_content: str) -> str:
|
||||
# Add the match blocks containing gram and gram-size to title and content fields
|
||||
schema_content = re.sub(
|
||||
@ -163,8 +180,10 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
index_embedding_dim: int,
|
||||
primary_embedding_dim: int,
|
||||
primary_embedding_precision: EmbeddingPrecision,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
secondary_index_embedding_precision: EmbeddingPrecision | None,
|
||||
) -> None:
|
||||
if MULTI_TENANT:
|
||||
logger.info(
|
||||
@ -221,18 +240,29 @@ class VespaIndex(DocumentIndex):
|
||||
schema_template = schema_f.read()
|
||||
schema_template = schema_template.replace(TENANT_ID_PAT, "")
|
||||
|
||||
schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
|
||||
schema = _replace_template_values_in_schema(
|
||||
schema_template,
|
||||
self.index_name,
|
||||
primary_embedding_dim,
|
||||
primary_embedding_precision,
|
||||
)
|
||||
|
||||
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
|
||||
schema = schema.replace(TENANT_ID_PAT, "")
|
||||
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
|
||||
|
||||
if self.secondary_index_name:
|
||||
upcoming_schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
|
||||
if secondary_index_embedding_dim is None:
|
||||
raise ValueError("Secondary index embedding dimension is required")
|
||||
if secondary_index_embedding_precision is None:
|
||||
raise ValueError("Secondary index embedding precision is required")
|
||||
|
||||
upcoming_schema = _replace_template_values_in_schema(
|
||||
schema_template,
|
||||
self.secondary_index_name,
|
||||
secondary_index_embedding_dim,
|
||||
secondary_index_embedding_precision,
|
||||
)
|
||||
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
|
||||
|
||||
zip_file = in_memory_zip_from_file_bytes(zip_dict)
|
||||
@ -251,6 +281,7 @@ class VespaIndex(DocumentIndex):
|
||||
def register_multitenant_indices(
|
||||
indices: list[str],
|
||||
embedding_dims: list[int],
|
||||
embedding_precisions: list[EmbeddingPrecision],
|
||||
) -> None:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError("Multi-tenant is not enabled")
|
||||
@ -309,13 +340,14 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
for i, index_name in enumerate(indices):
|
||||
embedding_dim = embedding_dims[i]
|
||||
embedding_precision = embedding_precisions[i]
|
||||
logger.info(
|
||||
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
|
||||
)
|
||||
|
||||
schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
|
||||
schema = _replace_template_values_in_schema(
|
||||
schema_template, index_name, embedding_dim, embedding_precision
|
||||
)
|
||||
schema = schema.replace(
|
||||
TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
|
||||
)
|
||||
|
@ -6,6 +6,7 @@ from onyx.configs.app_configs import VESPA_TENANT_PORT
|
||||
from onyx.configs.constants import SOURCE_TYPE
|
||||
|
||||
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
|
||||
EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
|
||||
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
|
||||
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
|
||||
|
@ -38,6 +38,7 @@ class IndexingEmbedder(ABC):
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
deployment_name: str | None,
|
||||
reduced_dimension: int | None,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
@ -60,6 +61,7 @@ class IndexingEmbedder(ABC):
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
deployment_name=deployment_name,
|
||||
reduced_dimension=reduced_dimension,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
@ -87,6 +89,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@ -99,6 +102,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
api_url,
|
||||
api_version,
|
||||
deployment_name,
|
||||
reduced_dimension,
|
||||
callback,
|
||||
)
|
||||
|
||||
@ -219,6 +223,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
api_url=search_settings.api_url,
|
||||
api_version=search_settings.api_version,
|
||||
deployment_name=search_settings.deployment_name,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
|
@ -5,6 +5,7 @@ from pydantic import Field
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
@ -143,10 +144,20 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
model_dim: int
|
||||
index_name: str | None
|
||||
multipass_indexing: bool
|
||||
embedding_precision: EmbeddingPrecision
|
||||
reduced_dimension: int | None = None
|
||||
|
||||
background_reindex_enabled: bool = True
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
@property
|
||||
def final_embedding_dim(self) -> int:
|
||||
if self.reduced_dimension:
|
||||
return self.reduced_dimension
|
||||
return self.model_dim
|
||||
|
||||
@classmethod
|
||||
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
|
||||
return cls(
|
||||
@ -158,6 +169,9 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
provider_type=search_settings.provider_type,
|
||||
index_name=search_settings.index_name,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
)
|
||||
|
||||
|
||||
|
@ -19,6 +19,7 @@ from onyx.utils.special_types import JSON_ro
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
|
@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
@ -323,7 +322,6 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
@ -89,6 +89,7 @@ class EmbeddingModel:
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
api_version: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.provider_type = provider_type
|
||||
@ -100,6 +101,7 @@ class EmbeddingModel:
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.deployment_name = deployment_name
|
||||
self.reduced_dimension = reduced_dimension
|
||||
self.tokenizer = get_tokenizer(
|
||||
model_name=model_name, provider_type=provider_type
|
||||
)
|
||||
@ -188,6 +190,7 @@ class EmbeddingModel:
|
||||
manual_query_prefix=self.query_prefix,
|
||||
manual_passage_prefix=self.passage_prefix,
|
||||
api_url=self.api_url,
|
||||
reduced_dimension=self.reduced_dimension,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
@ -300,6 +303,7 @@ class EmbeddingModel:
|
||||
retrim_content=retrim_content,
|
||||
api_version=search_settings.api_version,
|
||||
deployment_name=search_settings.deployment_name,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
)
|
||||
|
||||
|
||||
|
@ -31,12 +31,18 @@ from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.icons import source_to_github_img_link
|
||||
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessage
|
||||
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageChannelConfig
|
||||
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageMessageInfo
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import build_continue_in_web_ui_id
|
||||
from onyx.onyxbot.slack.utils import build_feedback_id
|
||||
from onyx.onyxbot.slack.utils import build_publish_ephemeral_message_id
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.onyxbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
@ -105,6 +111,77 @@ def _build_qa_feedback_block(
|
||||
)
|
||||
|
||||
|
||||
def _build_ephemeral_publication_block(
|
||||
channel_id: str,
|
||||
chat_message_id: int,
|
||||
message_info: SlackMessageInfo,
|
||||
original_question_ts: str,
|
||||
channel_conf: ChannelConfig,
|
||||
feedback_reminder_id: str | None = None,
|
||||
) -> Block:
|
||||
# check whether the message is in a thread
|
||||
if (
|
||||
message_info is not None
|
||||
and message_info.msg_to_respond is not None
|
||||
and message_info.thread_to_respond is not None
|
||||
and (message_info.msg_to_respond == message_info.thread_to_respond)
|
||||
):
|
||||
respond_ts = None
|
||||
else:
|
||||
respond_ts = original_question_ts
|
||||
|
||||
action_values_ephemeral_message_channel_config = (
|
||||
ActionValuesEphemeralMessageChannelConfig(
|
||||
channel_name=channel_conf.get("channel_name"),
|
||||
respond_tag_only=channel_conf.get("respond_tag_only"),
|
||||
respond_to_bots=channel_conf.get("respond_to_bots"),
|
||||
is_ephemeral=channel_conf.get("is_ephemeral", False),
|
||||
respond_member_group_list=channel_conf.get("respond_member_group_list"),
|
||||
answer_filters=channel_conf.get("answer_filters"),
|
||||
follow_up_tags=channel_conf.get("follow_up_tags"),
|
||||
show_continue_in_web_ui=channel_conf.get("show_continue_in_web_ui", False),
|
||||
)
|
||||
)
|
||||
|
||||
action_values_ephemeral_message_message_info = (
|
||||
ActionValuesEphemeralMessageMessageInfo(
|
||||
bypass_filters=message_info.bypass_filters,
|
||||
channel_to_respond=message_info.channel_to_respond,
|
||||
msg_to_respond=message_info.msg_to_respond,
|
||||
email=message_info.email,
|
||||
sender_id=message_info.sender_id,
|
||||
thread_messages=[],
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
is_bot_dm=message_info.is_bot_dm,
|
||||
thread_to_respond=respond_ts,
|
||||
)
|
||||
)
|
||||
|
||||
action_values_ephemeral_message = ActionValuesEphemeralMessage(
|
||||
original_question_ts=original_question_ts,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
chat_message_id=chat_message_id,
|
||||
message_info=action_values_ephemeral_message_message_info,
|
||||
channel_conf=action_values_ephemeral_message_channel_config,
|
||||
)
|
||||
|
||||
return ActionsBlock(
|
||||
block_id=build_publish_ephemeral_message_id(original_question_ts),
|
||||
elements=[
|
||||
ButtonElement(
|
||||
action_id=SHOW_EVERYONE_ACTION_ID,
|
||||
text="📢 Share with Everyone",
|
||||
value=action_values_ephemeral_message.model_dump_json(),
|
||||
),
|
||||
ButtonElement(
|
||||
action_id=KEEP_TO_YOURSELF_ACTION_ID,
|
||||
text="🤫 Keep to Yourself",
|
||||
value=action_values_ephemeral_message.model_dump_json(),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def get_document_feedback_blocks() -> Block:
|
||||
return SectionBlock(
|
||||
text=(
|
||||
@ -486,16 +563,21 @@ def build_slack_response_blocks(
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
skip_ai_feedback: bool = False,
|
||||
offer_ephemeral_publication: bool = False,
|
||||
expecting_search_result: bool = False,
|
||||
skip_restated_question: bool = False,
|
||||
) -> list[Block]:
|
||||
"""
|
||||
This function is a top level function that builds all the blocks for the Slack response.
|
||||
It also handles combining all the blocks together.
|
||||
"""
|
||||
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
if not skip_restated_question:
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
)
|
||||
else:
|
||||
restate_question_block = []
|
||||
|
||||
if expecting_search_result:
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
@ -520,12 +602,36 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
follow_up_block = []
|
||||
if channel_conf and channel_conf.get("follow_up_tags") is not None:
|
||||
if (
|
||||
channel_conf
|
||||
and channel_conf.get("follow_up_tags") is not None
|
||||
and not channel_conf.get("is_ephemeral", False)
|
||||
):
|
||||
follow_up_block.append(
|
||||
_build_follow_up_block(message_id=answer.chat_message_id)
|
||||
)
|
||||
|
||||
ai_feedback_block = []
|
||||
publish_ephemeral_message_block = []
|
||||
|
||||
if (
|
||||
offer_ephemeral_publication
|
||||
and answer.chat_message_id is not None
|
||||
and message_info.msg_to_respond is not None
|
||||
and channel_conf is not None
|
||||
):
|
||||
publish_ephemeral_message_block.append(
|
||||
_build_ephemeral_publication_block(
|
||||
channel_id=message_info.channel_to_respond,
|
||||
chat_message_id=answer.chat_message_id,
|
||||
original_question_ts=message_info.msg_to_respond,
|
||||
message_info=message_info,
|
||||
channel_conf=channel_conf,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
)
|
||||
|
||||
ai_feedback_block: list[Block] = []
|
||||
|
||||
if answer.chat_message_id is not None and not skip_ai_feedback:
|
||||
ai_feedback_block.append(
|
||||
_build_qa_feedback_block(
|
||||
@ -547,6 +653,7 @@ def build_slack_response_blocks(
|
||||
all_blocks = (
|
||||
restate_question_block
|
||||
+ answer_blocks
|
||||
+ publish_ephemeral_message_block
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
|
@ -2,6 +2,8 @@ from enum import Enum
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
KEEP_TO_YOURSELF_ACTION_ID = "keep-to-yourself"
|
||||
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
|
||||
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
|
||||
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@ -5,21 +6,32 @@ from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.views import View
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.webhook import WebhookClient
|
||||
|
||||
from onyx.chat.models import ChatOnyxBotResponse
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import get_chat_message
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import create_doc_retrieval_feedback
|
||||
from onyx.db.session import get_session_with_current_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.blocks import get_document_feedback_blocks
|
||||
from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel
|
||||
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FeedbackVisibility
|
||||
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from onyx.onyxbot.slack.handlers.handle_message import (
|
||||
remove_scheduled_feedback_reminder,
|
||||
@ -35,15 +47,48 @@ from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import get_feedback_visibility
|
||||
from onyx.onyxbot.slack.utils import read_slack_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import TenantSocketModeClient
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.server.query_and_chat.models import ChatMessageDetail
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_db_doc_id_to_document_ids(
|
||||
citation_dict: dict[int, int], top_documents: list[SavedSearchDoc]
|
||||
) -> list[CitationInfo]:
|
||||
citation_list_with_document_id = []
|
||||
for citation_num, db_doc_id in citation_dict.items():
|
||||
if db_doc_id is not None:
|
||||
matching_doc = next(
|
||||
(d for d in top_documents if d.db_doc_id == db_doc_id), None
|
||||
)
|
||||
if matching_doc:
|
||||
citation_list_with_document_id.append(
|
||||
CitationInfo(
|
||||
citation_num=citation_num, document_id=matching_doc.document_id
|
||||
)
|
||||
)
|
||||
return citation_list_with_document_id
|
||||
|
||||
|
||||
def _build_citation_list(chat_message_detail: ChatMessageDetail) -> list[CitationInfo]:
|
||||
citation_dict = chat_message_detail.citations
|
||||
if citation_dict is None:
|
||||
return []
|
||||
else:
|
||||
top_documents = (
|
||||
chat_message_detail.context_docs.top_documents
|
||||
if chat_message_detail.context_docs
|
||||
else []
|
||||
)
|
||||
citation_list = _convert_db_doc_id_to_document_ids(citation_dict, top_documents)
|
||||
return citation_list
|
||||
|
||||
|
||||
def handle_doc_feedback_button(
|
||||
req: SocketModeRequest,
|
||||
client: TenantSocketModeClient,
|
||||
@ -58,7 +103,7 @@ def handle_doc_feedback_button(
|
||||
external_id = build_feedback_id(query_event_id, doc_id, doc_rank)
|
||||
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
|
||||
data = View(
|
||||
type="modal",
|
||||
@ -84,7 +129,7 @@ def handle_generate_answer_button(
|
||||
channel_id = req.payload["channel"]["id"]
|
||||
channel_name = req.payload["channel"]["name"]
|
||||
message_ts = req.payload["message"]["ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
user_id = req.payload["user"]["id"]
|
||||
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
|
||||
email = expert_info.email if expert_info else None
|
||||
@ -106,7 +151,7 @@ def handle_generate_answer_button(
|
||||
|
||||
# tell the user that we're working on it
|
||||
# Send an ephemeral message to the user that we're generating the answer
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
receiver_ids=[user_id],
|
||||
@ -142,6 +187,178 @@ def handle_generate_answer_button(
|
||||
)
|
||||
|
||||
|
||||
def handle_publish_ephemeral_message_button(
|
||||
req: SocketModeRequest,
|
||||
client: TenantSocketModeClient,
|
||||
action_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
This function handles the Share with Everyone/Keep for Yourself buttons
|
||||
for ephemeral messages.
|
||||
"""
|
||||
channel_id = req.payload["channel"]["id"]
|
||||
ephemeral_message_ts = req.payload["container"]["message_ts"]
|
||||
|
||||
slack_sender_id = req.payload["user"]["id"]
|
||||
response_url = req.payload["response_url"]
|
||||
webhook = WebhookClient(url=response_url)
|
||||
|
||||
# The additional data required that was added to buttons.
|
||||
# Specifically, this contains the message_info, channel_conf information
|
||||
# and some additional attributes.
|
||||
value_dict = json.loads(req.payload["actions"][0]["value"])
|
||||
|
||||
original_question_ts = value_dict.get("original_question_ts")
|
||||
if not original_question_ts:
|
||||
raise ValueError("Missing original_question_ts in the payload")
|
||||
if not ephemeral_message_ts:
|
||||
raise ValueError("Missing ephemeral_message_ts in the payload")
|
||||
|
||||
feedback_reminder_id = value_dict.get("feedback_reminder_id")
|
||||
|
||||
slack_message_info = SlackMessageInfo(**value_dict["message_info"])
|
||||
channel_conf = value_dict.get("channel_conf")
|
||||
|
||||
user_email = value_dict.get("message_info", {}).get("email")
|
||||
|
||||
chat_message_id = value_dict.get("chat_message_id")
|
||||
|
||||
# Obtain onyx_user and chat_message information
|
||||
if not chat_message_id:
|
||||
raise ValueError("Missing chat_message_id in the payload")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
onyx_user = get_user_by_email(user_email, db_session)
|
||||
if not onyx_user:
|
||||
raise ValueError("Cannot determine onyx_user_id from email in payload")
|
||||
try:
|
||||
chat_message = get_chat_message(chat_message_id, onyx_user.id, db_session)
|
||||
except ValueError:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id, None, db_session
|
||||
) # is this good idea?
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get chat message: {e}")
|
||||
raise e
|
||||
|
||||
chat_message_detail = translate_db_message_to_chat_message_detail(chat_message)
|
||||
|
||||
# construct the proper citation format and then the answer in the suitable format
|
||||
# we need to construct the blocks.
|
||||
citation_list = _build_citation_list(chat_message_detail)
|
||||
|
||||
onyx_bot_answer = ChatOnyxBotResponse(
|
||||
answer=chat_message_detail.message,
|
||||
citations=citation_list,
|
||||
chat_message_id=chat_message_id,
|
||||
docs=QADocsResponse(
|
||||
top_documents=chat_message_detail.context_docs.top_documents
|
||||
if chat_message_detail.context_docs
|
||||
else [],
|
||||
predicted_flow=None,
|
||||
predicted_search=None,
|
||||
applied_source_filters=None,
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
llm_selected_doc_indices=None,
|
||||
error_msg=None,
|
||||
)
|
||||
|
||||
# Note: we need to use the webhook and the respond_url to update/delete ephemeral messages
|
||||
if action_id == SHOW_EVERYONE_ACTION_ID:
|
||||
# Convert to non-ephemeral message in thread
|
||||
try:
|
||||
webhook.send(
|
||||
response_type="ephemeral",
|
||||
text="",
|
||||
blocks=[],
|
||||
replace_original=True,
|
||||
delete_original=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook: {e}")
|
||||
|
||||
# remove handling of empheremal block and add AI feedback.
|
||||
all_blocks = build_slack_response_blocks(
|
||||
answer=onyx_bot_answer,
|
||||
message_info=slack_message_info,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
skip_ai_feedback=False,
|
||||
offer_ephemeral_publication=False,
|
||||
skip_restated_question=True,
|
||||
)
|
||||
try:
|
||||
# Post in thread as non-ephemeral message
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
receiver_ids=None, # If respond_member_group_list is set, send to them. TODO: check!
|
||||
text="Hello! Onyx has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=original_question_ts,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
send_as_ephemeral=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to publish ephemeral message: {e}")
|
||||
raise e
|
||||
|
||||
elif action_id == KEEP_TO_YOURSELF_ACTION_ID:
|
||||
# Keep as ephemeral message in channel or thread, but remove the publish button and add feedback button
|
||||
|
||||
changed_blocks = build_slack_response_blocks(
|
||||
answer=onyx_bot_answer,
|
||||
message_info=slack_message_info,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
skip_ai_feedback=False,
|
||||
offer_ephemeral_publication=False,
|
||||
skip_restated_question=True,
|
||||
)
|
||||
|
||||
try:
|
||||
if slack_message_info.thread_to_respond is not None:
|
||||
# There seems to be a bug in slack where an update within the thread
|
||||
# actually leads to the update to be posted in the channel. Therefore,
|
||||
# for now we delete the original ephemeral message and post a new one
|
||||
# if the ephemeral message is in a thread.
|
||||
webhook.send(
|
||||
response_type="ephemeral",
|
||||
text="",
|
||||
blocks=[],
|
||||
replace_original=True,
|
||||
delete_original=True,
|
||||
)
|
||||
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
receiver_ids=[slack_sender_id],
|
||||
text="Your personal response, sent as an ephemeral message.",
|
||||
blocks=changed_blocks,
|
||||
thread_ts=original_question_ts,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
send_as_ephemeral=True,
|
||||
)
|
||||
else:
|
||||
# This works fine if the ephemeral message is in the channel
|
||||
webhook.send(
|
||||
response_type="ephemeral",
|
||||
text="Your personal response, sent as an ephemeral message.",
|
||||
blocks=changed_blocks,
|
||||
replace_original=True,
|
||||
delete_original=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook: {e}")
|
||||
|
||||
|
||||
def handle_slack_feedback(
|
||||
feedback_id: str,
|
||||
feedback_type: str,
|
||||
@ -153,13 +370,20 @@ def handle_slack_feedback(
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
# Get Onyx user from Slack ID
|
||||
expert_info = expert_info_from_slack_id(
|
||||
user_id_to_post_confirmation, client, user_cache={}
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
onyx_user = get_user_by_email(email, db_session) if email else None
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
feedback_text="",
|
||||
chat_message_id=message_id,
|
||||
user_id=None, # no "user" for Slack bot for now
|
||||
user_id=onyx_user.id if onyx_user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
remove_scheduled_feedback_reminder(
|
||||
@ -213,7 +437,7 @@ def handle_slack_feedback(
|
||||
else:
|
||||
msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer"
|
||||
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel_id_to_post_confirmation,
|
||||
text=msg,
|
||||
@ -232,7 +456,7 @@ def handle_followup_button(
|
||||
action_id = cast(str, action.get("block_id"))
|
||||
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
|
||||
update_emote_react(
|
||||
emoji=DANSWER_FOLLOWUP_EMOJI,
|
||||
@ -265,7 +489,7 @@ def handle_followup_button(
|
||||
|
||||
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
|
||||
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
text="Received your request for more help",
|
||||
@ -315,7 +539,7 @@ def handle_followup_resolved_button(
|
||||
) -> None:
|
||||
channel_id = req.payload["container"]["channel_id"]
|
||||
message_ts = req.payload["container"]["message_ts"]
|
||||
thread_ts = req.payload["container"]["thread_ts"]
|
||||
thread_ts = req.payload["container"].get("thread_ts", None)
|
||||
|
||||
clicker_name = get_clicker_name(req, client)
|
||||
|
||||
@ -349,7 +573,7 @@ def handle_followup_resolved_button(
|
||||
|
||||
resolved_block = SectionBlock(text=msg_text)
|
||||
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=channel_id,
|
||||
text="Your request for help as been addressed!",
|
||||
|
@ -18,7 +18,7 @@ from onyx.onyxbot.slack.handlers.handle_standard_answers import (
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
|
||||
from onyx.onyxbot.slack.utils import fetch_user_ids_from_groups
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import slack_usage_report
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.utils.logger import setup_logger
|
||||
@ -29,7 +29,7 @@ logger_base = setup_logger()
|
||||
|
||||
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
if details.is_bot_msg and details.sender_id:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=details.channel_to_respond,
|
||||
thread_ts=details.msg_to_respond,
|
||||
@ -202,7 +202,7 @@ def handle_message(
|
||||
# which would just respond to the sender
|
||||
if send_to and is_bot_msg:
|
||||
if sender_id:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=[sender_id],
|
||||
@ -220,6 +220,7 @@ def handle_message(
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
# standard answers should be published in a thread
|
||||
used_standard_answer = handle_standard_answers(
|
||||
message_info=message_info,
|
||||
receiver_ids=send_to,
|
||||
|
@ -33,7 +33,7 @@ from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.handlers.utils import slackify_message_thread
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
@ -82,12 +82,38 @@ def handle_regular_answer(
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
|
||||
# Capture whether response mode for channel is ephemeral. Even if the channel is set
|
||||
# to respond with an ephemeral message, we still send as non-ephemeral if
|
||||
# the message is a dm with the Onyx bot.
|
||||
send_as_ephemeral = (
|
||||
slack_channel_config.channel_config.get("is_ephemeral", False)
|
||||
and not message_info.is_bot_dm
|
||||
)
|
||||
|
||||
# If the channel mis configured to respond with an ephemeral message,
|
||||
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
|
||||
# This will make documents privately accessible to the user available to Onyx Bot answers.
|
||||
# Otherwise - if not ephemeral or DM to Onyx Bot - we must use None as the user to restrict
|
||||
# to public docs.
|
||||
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.is_bot_dm or send_as_ephemeral:
|
||||
if message_info.email:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
target_thread_ts = (
|
||||
None
|
||||
if send_as_ephemeral and len(message_info.thread_messages) < 2
|
||||
else message_ts_to_respond_to
|
||||
)
|
||||
target_receiver_ids = (
|
||||
[message_info.sender_id]
|
||||
if message_info.sender_id and send_as_ephemeral
|
||||
else receiver_ids
|
||||
)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
prompt = None
|
||||
# If no persona is specified, use the default search based persona
|
||||
@ -134,11 +160,10 @@ def handle_regular_answer(
|
||||
history_messages = messages[:-1]
|
||||
single_message_history = slackify_message_thread(history_messages) or None
|
||||
|
||||
# Always check for ACL permissions, also for documnt sets that were explicitly added
|
||||
# to the Bot by the Administrator. (Change relative to earlier behavior where all documents
|
||||
# in an attached document set were available to all users in the channel.)
|
||||
bypass_acl = False
|
||||
if slack_channel_config.persona and slack_channel_config.persona.document_sets:
|
||||
# For Slack channels, use the full document set, admin will be warned when configuring it
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/onyx" command, then it should have a message ts to respond to
|
||||
@ -219,12 +244,13 @@ def handle_regular_answer(
|
||||
# Optionally, respond in thread with the error message, Used primarily
|
||||
# for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text=f"Encountered exception when trying to answer: \n\n```{e}```",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
# In case of failures, don't keep the reaction there permanently
|
||||
@ -242,32 +268,36 @@ def handle_regular_answer(
|
||||
if answer is None:
|
||||
assert DISABLE_GENERATIVE_AI is True
|
||||
try:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=receiver_ids,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Hello! Onyx has some results for you!",
|
||||
blocks=[
|
||||
SectionBlock(
|
||||
text="Onyx is down for maintenance.\nWe're working hard on recharging the AI!"
|
||||
)
|
||||
],
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
if receiver_ids:
|
||||
respond_in_thread(
|
||||
|
||||
# If the channel is ephemeral, we don't need to send a message to the user since they will already see the message
|
||||
if target_receiver_ids and not send_as_ephemeral:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
"👋 Hi, we've just gathered and forwarded the relevant "
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
return False
|
||||
@ -316,12 +346,13 @@ def handle_regular_answer(
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Found no documents when trying to answer. Did you index any documents?",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
return True
|
||||
|
||||
@ -349,15 +380,27 @@ def handle_regular_answer(
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
if should_respond_with_error_msgs:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Found no citations or quotes when trying to answer.",
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
return True
|
||||
|
||||
if (
|
||||
send_as_ephemeral
|
||||
and target_receiver_ids is not None
|
||||
and len(target_receiver_ids) == 1
|
||||
):
|
||||
offer_ephemeral_publication = True
|
||||
skip_ai_feedback = True
|
||||
else:
|
||||
offer_ephemeral_publication = False
|
||||
skip_ai_feedback = False if feedback_reminder_id else True
|
||||
|
||||
all_blocks = build_slack_response_blocks(
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
@ -365,31 +408,39 @@ def handle_regular_answer(
|
||||
use_citations=True, # No longer supporting quotes
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
expecting_search_result=expecting_search_result,
|
||||
offer_ephemeral_publication=offer_ephemeral_publication,
|
||||
skip_ai_feedback=skip_ai_feedback,
|
||||
)
|
||||
|
||||
try:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=[message_info.sender_id]
|
||||
if message_info.is_bot_msg and message_info.sender_id
|
||||
else receiver_ids,
|
||||
receiver_ids=target_receiver_ids,
|
||||
text="Hello! Onyx has some results for you!",
|
||||
blocks=all_blocks,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
|
||||
unfurl=False,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
# For DM (ephemeral message), we need to create a thread via a normal message so the user can see
|
||||
# the ephemeral message. This also will give the user a notification which ephemeral message does not.
|
||||
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /onyx message
|
||||
# so we shouldn't send_team_member_message
|
||||
if receiver_ids and message_ts_to_respond_to is not None:
|
||||
if (
|
||||
target_receiver_ids
|
||||
and message_ts_to_respond_to is not None
|
||||
and not send_as_ephemeral
|
||||
and target_thread_ts is not None
|
||||
):
|
||||
send_team_member_message(
|
||||
client=client,
|
||||
channel=channel,
|
||||
thread_ts=message_ts_to_respond_to,
|
||||
thread_ts=target_thread_ts,
|
||||
receiver_ids=target_receiver_ids,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
||||
return False
|
||||
|
@ -2,7 +2,7 @@ from slack_sdk import WebClient
|
||||
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
@ -32,8 +32,10 @@ def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
thread_ts: str,
|
||||
receiver_ids: list[str] | None = None,
|
||||
send_as_ephemeral: bool = False,
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
text=(
|
||||
@ -41,4 +43,6 @@ def send_team_member_message(
|
||||
+ "information to the team. They'll get back to you shortly!"
|
||||
),
|
||||
thread_ts=thread_ts,
|
||||
receiver_ids=None,
|
||||
send_as_ephemeral=send_as_ephemeral,
|
||||
)
|
||||
|
@ -57,7 +57,9 @@ from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
|
||||
from onyx.onyxbot.slack.handlers.handle_buttons import handle_doc_feedback_button
|
||||
from onyx.onyxbot.slack.handlers.handle_buttons import handle_followup_button
|
||||
@ -67,6 +69,9 @@ from onyx.onyxbot.slack.handlers.handle_buttons import (
|
||||
from onyx.onyxbot.slack.handlers.handle_buttons import (
|
||||
handle_generate_answer_button,
|
||||
)
|
||||
from onyx.onyxbot.slack.handlers.handle_buttons import (
|
||||
handle_publish_ephemeral_message_button,
|
||||
)
|
||||
from onyx.onyxbot.slack.handlers.handle_buttons import handle_slack_feedback
|
||||
from onyx.onyxbot.slack.handlers.handle_message import handle_message
|
||||
from onyx.onyxbot.slack.handlers.handle_message import (
|
||||
@ -81,7 +86,7 @@ from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id
|
||||
from onyx.onyxbot.slack.utils import read_slack_thread
|
||||
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
|
||||
from onyx.onyxbot.slack.utils import rephrase_slack_message
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import TenantSocketModeClient
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.manage.models import SlackBotTokens
|
||||
@ -667,7 +672,11 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
|
||||
feedback_msg_reminder = cast(str, action.get("value"))
|
||||
feedback_id = cast(str, action.get("block_id"))
|
||||
channel_id = cast(str, req.payload["container"]["channel_id"])
|
||||
thread_ts = cast(str, req.payload["container"]["thread_ts"])
|
||||
thread_ts = cast(
|
||||
str,
|
||||
req.payload["container"].get("thread_ts")
|
||||
or req.payload["container"].get("message_ts"),
|
||||
)
|
||||
else:
|
||||
logger.error("Unable to process feedback. Action not found")
|
||||
return
|
||||
@ -783,7 +792,7 @@ def apologize_for_fail(
|
||||
details: SlackMessageInfo,
|
||||
client: TenantSocketModeClient,
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client.web_client,
|
||||
channel=details.channel_to_respond,
|
||||
thread_ts=details.msg_to_respond,
|
||||
@ -859,6 +868,14 @@ def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> No
|
||||
if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]:
|
||||
# AI Answer feedback
|
||||
return process_feedback(req, client)
|
||||
elif action["action_id"] in [
|
||||
SHOW_EVERYONE_ACTION_ID,
|
||||
KEEP_TO_YOURSELF_ACTION_ID,
|
||||
]:
|
||||
# Publish ephemeral message or keep hidden in main channel
|
||||
return handle_publish_ephemeral_message_button(
|
||||
req, client, action["action_id"]
|
||||
)
|
||||
elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID:
|
||||
# Activation of the "source feedback" button
|
||||
return handle_doc_feedback_button(req, client)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import ThreadMessage
|
||||
@ -13,3 +15,37 @@ class SlackMessageInfo(BaseModel):
|
||||
bypass_filters: bool # User has tagged @OnyxBot
|
||||
is_bot_msg: bool # User is using /OnyxBot
|
||||
is_bot_dm: bool # User is direct messaging to OnyxBot
|
||||
|
||||
|
||||
# Models used to encode the relevant data for the ephemeral message actions
|
||||
class ActionValuesEphemeralMessageMessageInfo(BaseModel):
|
||||
bypass_filters: bool | None
|
||||
channel_to_respond: str | None
|
||||
msg_to_respond: str | None
|
||||
email: str | None
|
||||
sender_id: str | None
|
||||
thread_messages: list[ThreadMessage] | None
|
||||
is_bot_msg: bool | None
|
||||
is_bot_dm: bool | None
|
||||
thread_to_respond: str | None
|
||||
|
||||
|
||||
class ActionValuesEphemeralMessageChannelConfig(BaseModel):
|
||||
channel_name: str | None
|
||||
respond_tag_only: bool | None
|
||||
respond_to_bots: bool | None
|
||||
is_ephemeral: bool
|
||||
respond_member_group_list: list[str] | None
|
||||
answer_filters: list[
|
||||
Literal["well_answered_postfilter", "questionmark_prefilter"]
|
||||
] | None
|
||||
follow_up_tags: list[str] | None
|
||||
show_continue_in_web_ui: bool
|
||||
|
||||
|
||||
class ActionValuesEphemeralMessage(BaseModel):
|
||||
original_question_ts: str | None
|
||||
feedback_reminder_id: str | None
|
||||
chat_message_id: int
|
||||
message_info: ActionValuesEphemeralMessageMessageInfo
|
||||
channel_conf: ActionValuesEphemeralMessageChannelConfig
|
||||
|
@ -184,7 +184,7 @@ def _build_error_block(error_message: str) -> Block:
|
||||
backoff=2,
|
||||
logger=cast(logging.Logger, logger),
|
||||
)
|
||||
def respond_in_thread(
|
||||
def respond_in_thread_or_channel(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
thread_ts: str | None,
|
||||
@ -193,6 +193,7 @@ def respond_in_thread(
|
||||
receiver_ids: list[str] | None = None,
|
||||
metadata: Metadata | None = None,
|
||||
unfurl: bool = True,
|
||||
send_as_ephemeral: bool | None = True,
|
||||
) -> list[str]:
|
||||
if not text and not blocks:
|
||||
raise ValueError("One of `text` or `blocks` must be provided")
|
||||
@ -236,6 +237,7 @@ def respond_in_thread(
|
||||
message_ids.append(response["message_ts"])
|
||||
else:
|
||||
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
|
||||
|
||||
for receiver in receiver_ids:
|
||||
try:
|
||||
response = slack_call(
|
||||
@ -299,6 +301,12 @@ def build_feedback_id(
|
||||
return unique_prefix + ID_SEPARATOR + feedback_id
|
||||
|
||||
|
||||
def build_publish_ephemeral_message_id(
|
||||
original_question_ts: str,
|
||||
) -> str:
|
||||
return "publish_ephemeral_message__" + original_question_ts
|
||||
|
||||
|
||||
def build_continue_in_web_ui_id(
|
||||
message_id: int,
|
||||
) -> str:
|
||||
@ -539,7 +547,7 @@ def read_slack_thread(
|
||||
|
||||
# If auto-detected filters are on, use the second block for the actual answer
|
||||
# The first block is the auto-detected filters
|
||||
if message.startswith("_Filters"):
|
||||
if message is not None and message.startswith("_Filters"):
|
||||
if len(blocks) < 2:
|
||||
logger.warning(f"Only filter blocks found: {reply}")
|
||||
continue
|
||||
@ -611,7 +619,7 @@ class SlackRateLimiter:
|
||||
def notify(
|
||||
self, client: WebClient, channel: str, position: int, thread_ts: str | None
|
||||
) -> None:
|
||||
respond_in_thread(
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=channel,
|
||||
receiver_ids=None,
|
||||
|
@ -13,6 +13,7 @@ from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
|
||||
from onyx.db.credentials import delete_credential
|
||||
from onyx.db.credentials import delete_credential_for_user
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
@ -88,7 +89,7 @@ def delete_credential_by_id_admin(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
"""Same as the user endpoint, but can delete any credential (not just the user's own)"""
|
||||
delete_credential(db_session=db_session, credential_id=credential_id, user=None)
|
||||
delete_credential(db_session=db_session, credential_id=credential_id)
|
||||
return StatusResponse(
|
||||
success=True, message="Credential deleted successfully", data=credential_id
|
||||
)
|
||||
@ -242,7 +243,7 @@ def delete_credential_by_id(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
delete_credential(
|
||||
delete_credential_for_user(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
@ -259,7 +260,7 @@ def force_delete_credential_by_id(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
delete_credential(credential_id, user, db_session, True)
|
||||
delete_credential_for_user(credential_id, user, db_session, True)
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Credential deleted successfully", data=credential_id
|
||||
|
@ -49,6 +49,7 @@ def get_folders(
|
||||
name=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
time_created=chat_session.time_created.isoformat(),
|
||||
time_updated=chat_session.time_updated.isoformat(),
|
||||
shared_status=chat_session.shared_status,
|
||||
folder_id=folder.id,
|
||||
)
|
||||
|
@ -181,6 +181,7 @@ class SlackChannelConfigCreationRequest(BaseModel):
|
||||
channel_name: str
|
||||
respond_tag_only: bool = False
|
||||
respond_to_bots: bool = False
|
||||
is_ephemeral: bool = False
|
||||
show_continue_in_web_ui: bool = False
|
||||
enable_auto_filters: bool = False
|
||||
# If no team members, assume respond in the channel to everyone
|
||||
|
@ -72,11 +72,13 @@ def set_new_search_settings(
|
||||
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
search_values = search_settings_new.dict()
|
||||
search_values = search_settings_new.model_dump()
|
||||
search_values["index_name"] = index_name
|
||||
new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
else:
|
||||
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
|
||||
new_search_settings_request = SavedSearchSettings(
|
||||
**search_settings_new.model_dump()
|
||||
)
|
||||
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
@ -103,8 +105,10 @@ def set_new_search_settings(
|
||||
document_index = get_default_document_index(search_settings, new_search_settings)
|
||||
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=search_settings.model_dim,
|
||||
secondary_index_embedding_dim=new_search_settings.model_dim,
|
||||
primary_embedding_dim=search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=search_settings.embedding_precision,
|
||||
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
secondary_index_embedding_precision=new_search_settings.embedding_precision,
|
||||
)
|
||||
|
||||
# Pause index attempts for the currently in use index to preserve resources
|
||||
@ -137,6 +141,17 @@ def cancel_new_embedding(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# remove the old index from the vector db
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(primary_search_settings, None)
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=primary_search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=primary_search_settings.embedding_precision,
|
||||
# just finished swap, no more secondary index
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/delete-search-settings")
|
||||
def delete_search_settings_endpoint(
|
||||
|
@ -71,6 +71,15 @@ def _form_channel_config(
|
||||
"also respond to a predetermined set of users."
|
||||
)
|
||||
|
||||
if (
|
||||
slack_channel_config_creation_request.is_ephemeral
|
||||
and slack_channel_config_creation_request.respond_member_group_list
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot set OnyxBot to respond to users in a private (ephemeral) message "
|
||||
"and also respond to a selected list of users."
|
||||
)
|
||||
|
||||
channel_config: ChannelConfig = {
|
||||
"channel_name": cleaned_channel_name,
|
||||
}
|
||||
@ -91,6 +100,8 @@ def _form_channel_config(
|
||||
"respond_to_bots"
|
||||
] = slack_channel_config_creation_request.respond_to_bots
|
||||
|
||||
channel_config["is_ephemeral"] = slack_channel_config_creation_request.is_ephemeral
|
||||
|
||||
channel_config["disabled"] = slack_channel_config_creation_request.disabled
|
||||
|
||||
return channel_config
|
||||
@ -343,7 +354,8 @@ def list_bot_configs(
|
||||
]
|
||||
|
||||
|
||||
MAX_CHANNELS = 200
|
||||
MAX_SLACK_PAGES = 5
|
||||
SLACK_API_CHANNELS_PER_PAGE = 100
|
||||
|
||||
|
||||
@router.get(
|
||||
@ -355,8 +367,8 @@ def get_all_channels_from_slack_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[SlackChannel]:
|
||||
"""
|
||||
Fetches all channels from the Slack API.
|
||||
If the workspace has 200 or more channels, we raise an error.
|
||||
Fetches channels the bot is a member of from the Slack API.
|
||||
Handles pagination with a limit to avoid excessive API calls.
|
||||
"""
|
||||
tokens = fetch_slack_bot_tokens(db_session, bot_id)
|
||||
if not tokens or "bot_token" not in tokens:
|
||||
@ -365,28 +377,60 @@ def get_all_channels_from_slack_api(
|
||||
)
|
||||
|
||||
client = WebClient(token=tokens["bot_token"])
|
||||
all_channels = []
|
||||
next_cursor = None
|
||||
current_page = 0
|
||||
|
||||
try:
|
||||
response = client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=MAX_CHANNELS,
|
||||
)
|
||||
# Use users_conversations with limited pagination
|
||||
while current_page < MAX_SLACK_PAGES:
|
||||
current_page += 1
|
||||
|
||||
# Make API call with cursor if we have one
|
||||
if next_cursor:
|
||||
response = client.users_conversations(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
cursor=next_cursor,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
)
|
||||
else:
|
||||
response = client.users_conversations(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
)
|
||||
|
||||
# Add channels to our list
|
||||
if "channels" in response and response["channels"]:
|
||||
all_channels.extend(response["channels"])
|
||||
|
||||
# Check if we need to paginate
|
||||
if (
|
||||
"response_metadata" in response
|
||||
and "next_cursor" in response["response_metadata"]
|
||||
):
|
||||
next_cursor = response["response_metadata"]["next_cursor"]
|
||||
if next_cursor:
|
||||
if current_page == MAX_SLACK_PAGES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Workspace has too many channels to paginate over in this call.",
|
||||
)
|
||||
continue
|
||||
|
||||
# If we get here, no more pages
|
||||
break
|
||||
|
||||
channels = [
|
||||
SlackChannel(id=channel["id"], name=channel["name"])
|
||||
for channel in response["channels"]
|
||||
for channel in all_channels
|
||||
]
|
||||
|
||||
if len(channels) == MAX_CHANNELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
|
||||
)
|
||||
|
||||
return channels
|
||||
|
||||
except SlackApiError as e:
|
||||
# Handle rate limiting or other API errors
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error fetching channels from Slack API: {str(e)}",
|
||||
|
@ -147,9 +147,11 @@ def list_threads(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
|
@ -119,6 +119,7 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
|
@ -181,6 +181,7 @@ class ChatSessionDetails(BaseModel):
|
||||
name: str | None
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
time_updated: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
current_alternate_model: str | None = None
|
||||
@ -241,6 +242,7 @@ class ChatMessageDetail(BaseModel):
|
||||
files: list[FileDescriptor]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
refined_answer_improvement: bool | None = None
|
||||
is_agentic: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
|
@ -159,6 +159,7 @@ def get_user_search_sessions(
|
||||
name=sessions_with_documents_dict[search.id],
|
||||
persona_id=search.persona_id,
|
||||
time_created=search.time_created.isoformat(),
|
||||
time_updated=search.time_updated.isoformat(),
|
||||
shared_status=search.shared_status,
|
||||
folder_id=search.folder_id,
|
||||
current_alternate_model=search.current_alternate_model,
|
||||
|
@ -46,13 +46,21 @@ def mask_string(sensitive_str: str) -> str:
|
||||
return "****...**" + sensitive_str[-4:]
|
||||
|
||||
|
||||
MASK_CREDENTIALS_WHITELIST = {
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
"wiki_base",
|
||||
"cloud_name",
|
||||
"cloud_id",
|
||||
}
|
||||
|
||||
|
||||
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
masked_creds = {}
|
||||
for key, val in credential_dict.items():
|
||||
if isinstance(val, str):
|
||||
# we want to pass the authentication_method field through so the frontend
|
||||
# can disambiguate credentials created by different methods
|
||||
if key == DB_CREDENTIALS_AUTHENTICATION_METHOD:
|
||||
if key in MASK_CREDENTIALS_WHITELIST:
|
||||
masked_creds[key] = val
|
||||
else:
|
||||
masked_creds[key] = mask_string(val)
|
||||
@ -63,8 +71,8 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to mask credentials of type other than string, cannot process request."
|
||||
f"Recieved type: {type(val)}"
|
||||
f"Unable to mask credentials of type other than string or int, cannot process request."
|
||||
f"Received type: {type(val)}"
|
||||
)
|
||||
|
||||
return masked_creds
|
||||
|
@ -21,6 +21,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import resync_cc_pair
|
||||
from onyx.db.credentials import create_initial_public_credential
|
||||
from onyx.db.document import check_docs_exist
|
||||
from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from onyx.db.index_attempt import expire_index_attempts
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
@ -32,7 +33,7 @@ from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_secondary_search_settings
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
@ -73,7 +74,7 @@ def setup_onyx(
|
||||
|
||||
The Tenant Service calls the tenants/create endpoint which runs this.
|
||||
"""
|
||||
check_index_swap(db_session=db_session)
|
||||
check_and_perform_index_swap(db_session=db_session)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
search_settings = active_search_settings.primary
|
||||
@ -243,10 +244,18 @@ def setup_vespa(
|
||||
try:
|
||||
logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...")
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=index_setting.model_dim,
|
||||
secondary_index_embedding_dim=secondary_index_setting.model_dim
|
||||
if secondary_index_setting
|
||||
else None,
|
||||
primary_embedding_dim=index_setting.final_embedding_dim,
|
||||
primary_embedding_precision=index_setting.embedding_precision,
|
||||
secondary_index_embedding_dim=(
|
||||
secondary_index_setting.final_embedding_dim
|
||||
if secondary_index_setting
|
||||
else None
|
||||
),
|
||||
secondary_index_embedding_precision=(
|
||||
secondary_index_setting.embedding_precision
|
||||
if secondary_index_setting
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
logger.notice("Vespa setup complete.")
|
||||
@ -360,6 +369,11 @@ def setup_vespa_multitenant(supported_indices: list[SupportedEmbeddingModel]) ->
|
||||
],
|
||||
embedding_dims=[index.dim for index in supported_indices]
|
||||
+ [index.dim for index in supported_indices],
|
||||
# on the cloud, just use float for all indices, the option to change this
|
||||
# is not exposed to the user
|
||||
embedding_precisions=[
|
||||
EmbeddingPrecision.FLOAT for _ in range(len(supported_indices) * 2)
|
||||
],
|
||||
)
|
||||
|
||||
logger.notice("Vespa setup complete.")
|
||||
|
@ -136,7 +136,7 @@ def seed_dummy_docs(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
index_name = search_settings.index_name
|
||||
embedding_dim = search_settings.model_dim
|
||||
embedding_dim = search_settings.final_embedding_dim
|
||||
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
|
@ -30,6 +30,12 @@ class EmbedRequest(BaseModel):
|
||||
manual_passage_prefix: str | None = None
|
||||
api_url: str | None = None
|
||||
api_version: str | None = None
|
||||
|
||||
# allows for the truncation of the vector to a lower dimension
|
||||
# to reduce memory usage. Currently only supported for OpenAI models.
|
||||
# will be ignored for other providers.
|
||||
reduced_dimension: int | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
@ -5,7 +5,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
from onyx.connectors.models import Document
|
||||
|
||||
|
||||
@ -18,12 +20,15 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
credentials_provider = OnyxStaticCredentialsProvider(
|
||||
None,
|
||||
DocumentSource.CONFLUENCE,
|
||||
{
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
}
|
||||
},
|
||||
)
|
||||
connector.set_credentials_provider(credentials_provider)
|
||||
return connector
|
||||
|
||||
|
||||
|
@ -2,7 +2,9 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -11,12 +13,16 @@ def confluence_connector() -> ConfluenceConnector:
|
||||
wiki_base="https://danswerai.atlassian.net",
|
||||
is_cloud=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
|
||||
credentials_provider = OnyxStaticCredentialsProvider(
|
||||
None,
|
||||
DocumentSource.CONFLUENCE,
|
||||
{
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
}
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
},
|
||||
)
|
||||
connector.set_credentials_provider(credentials_provider)
|
||||
return connector
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ from onyx.db.engine import SYNC_DB_API
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.session import get_session_context_manager
|
||||
from onyx.db.session import get_session_with_tenant
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.db.tenant import get_all_tenant_ids
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
|
||||
@ -194,7 +194,7 @@ def reset_vespa() -> None:
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
# swap to the correct default model
|
||||
check_index_swap(db_session)
|
||||
check_and_perform_index_swap(db_session)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
@ -289,7 +289,7 @@ def reset_vespa_multitenant() -> None:
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# swap to the correct default model for each tenant
|
||||
check_index_swap(db_session)
|
||||
check_and_perform_index_swap(db_session)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
|
@ -108,3 +108,13 @@ def admin_user() -> DATestUser:
|
||||
@pytest.fixture
|
||||
def reset_multitenant() -> None:
|
||||
reset_all_multitenant()
|
||||
|
||||
|
||||
def pytest_runtest_logstart(nodeid: str, location: tuple[str, int | None, str]) -> None:
|
||||
print(f"\nTest start: {nodeid}")
|
||||
|
||||
|
||||
def pytest_runtest_logfinish(
|
||||
nodeid: str, location: tuple[str, int | None, str]
|
||||
) -> None:
|
||||
print(f"\nTest end: {nodeid}")
|
||||
|
@ -142,8 +142,12 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
selected_cc_pair = CCPairManager.get_indexing_status_by_id(
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
assert selected_cc_pair is not None, "cc_pair not found after indexing!"
|
||||
assert selected_cc_pair.docs_indexed == 15
|
||||
|
||||
# used to be 15, but now
|
||||
# localhost:8889/ and localhost:8889/index.html are deduped
|
||||
assert selected_cc_pair.docs_indexed == 14
|
||||
|
||||
logger.info("Removing about.html.")
|
||||
os.remove(os.path.join(website_tgt, "about.html"))
|
||||
@ -160,24 +164,29 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
cc_pair_1.id, user_performing_action=admin_user
|
||||
)
|
||||
assert selected_cc_pair is not None, "cc_pair not found after pruning!"
|
||||
assert selected_cc_pair.docs_indexed == 13
|
||||
assert selected_cc_pair.docs_indexed == 12
|
||||
|
||||
# check vespa
|
||||
root_id = f"http://{hostname}:{port}/"
|
||||
index_id = f"http://{hostname}:{port}/index.html"
|
||||
about_id = f"http://{hostname}:{port}/about.html"
|
||||
courses_id = f"http://{hostname}:{port}/courses.html"
|
||||
|
||||
doc_ids = [index_id, about_id, courses_id]
|
||||
doc_ids = [root_id, index_id, about_id, courses_id]
|
||||
retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]
|
||||
retrieved_docs = {
|
||||
doc["fields"]["document_id"]: doc["fields"]
|
||||
for doc in retrieved_docs_dict
|
||||
}
|
||||
|
||||
# verify index.html exists in Vespa
|
||||
retrieved_doc = retrieved_docs.get(index_id)
|
||||
# verify root exists in Vespa
|
||||
retrieved_doc = retrieved_docs.get(root_id)
|
||||
assert retrieved_doc
|
||||
|
||||
# verify index.html does not exist in Vespa since it is a duplicate of root
|
||||
retrieved_doc = retrieved_docs.get(index_id)
|
||||
assert not retrieved_doc
|
||||
|
||||
# verify about and courses do not exist
|
||||
retrieved_doc = retrieved_docs.get(about_id)
|
||||
assert not retrieved_doc
|
||||
|
@ -64,7 +64,7 @@ async def test_openai_embedding(
|
||||
|
||||
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
|
||||
result = await embedding._embed_openai(
|
||||
["test1", "test2"], "text-embedding-ada-002"
|
||||
["test1", "test2"], "text-embedding-ada-002", None
|
||||
)
|
||||
|
||||
assert result == sample_embeddings
|
||||
@ -89,6 +89,7 @@ async def test_embed_text_cloud_provider() -> None:
|
||||
prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
reduced_dimension=None,
|
||||
)
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
@ -114,6 +115,7 @@ async def test_embed_text_local_model() -> None:
|
||||
prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
reduced_dimension=None,
|
||||
)
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
@ -157,6 +159,7 @@ async def test_rate_limit_handling() -> None:
|
||||
prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
reduced_dimension=None,
|
||||
)
|
||||
|
||||
|
||||
@ -179,6 +182,7 @@ async def test_concurrent_embeddings() -> None:
|
||||
manual_passage_prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
reduced_dimension=None,
|
||||
)
|
||||
|
||||
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
|
||||
|
@ -3,9 +3,7 @@ from unittest.mock import Mock
|
||||
import pytest
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
handle_confluence_rate_limit,
|
||||
)
|
||||
from onyx.connectors.confluence.utils import handle_confluence_rate_limit
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -50,6 +48,8 @@ def mock_confluence_call() -> Mock:
|
||||
# mock_sleep.assert_called_with(int(retry_after))
|
||||
|
||||
|
||||
# NOTE(rkuo): This tests an older version of rate limiting that is being deprecated
|
||||
# and probably should go away soon.
|
||||
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||
mock_confluence_call.side_effect = HTTPError(
|
||||
response=Mock(status_code=500, text="Internal Server Error")
|
||||
|
114
web/package-lock.json
generated
114
web/package-lock.json
generated
@ -52,7 +52,7 @@
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.454.0",
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"next": "^15.0.2",
|
||||
"next": "^15.2.0",
|
||||
"next-themes": "^0.4.4",
|
||||
"npm": "^10.8.0",
|
||||
"postcss": "^8.4.31",
|
||||
@ -2631,9 +2631,10 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.0.2.tgz",
|
||||
"integrity": "sha512-c0Zr0ModK5OX7D4ZV8Jt/wqoXtitLNPwUfG9zElCZztdaZyNVnN40rDXVZ/+FGuR4CcNV5AEfM6N8f+Ener7Dg=="
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.2.0.tgz",
|
||||
"integrity": "sha512-eMgJu1RBXxxqqnuRJQh5RozhskoNUDHBFybvi+Z+yK9qzKeG7dadhv/Vp1YooSZmCnegf7JxWuapV77necLZNA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
"version": "14.2.3",
|
||||
@ -2645,12 +2646,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.0.2.tgz",
|
||||
"integrity": "sha512-GK+8w88z+AFlmt+ondytZo2xpwlfAR8U6CRwXancHImh6EdGfHMIrTSCcx5sOSBei00GyLVL0ioo1JLKTfprgg==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.2.0.tgz",
|
||||
"integrity": "sha512-rlp22GZwNJjFCyL7h5wz9vtpBVuCt3ZYjFWpEPBGzG712/uL1bbSkS675rVAUCRZ4hjoTJ26Q7IKhr5DfJrHDA==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
@ -2660,12 +2662,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.0.2.tgz",
|
||||
"integrity": "sha512-KUpBVxIbjzFiUZhiLIpJiBoelqzQtVZbdNNsehhUn36e2YzKHphnK8eTUW1s/4aPy5kH/UTid8IuVbaOpedhpw==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.2.0.tgz",
|
||||
"integrity": "sha512-DiU85EqSHogCz80+sgsx90/ecygfCSGl5P3b4XDRVZpgujBm5lp4ts7YaHru7eVTyZMjHInzKr+w0/7+qDrvMA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"darwin"
|
||||
@ -2675,12 +2678,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.0.2.tgz",
|
||||
"integrity": "sha512-9J7TPEcHNAZvwxXRzOtiUvwtTD+fmuY0l7RErf8Yyc7kMpE47MIQakl+3jecmkhOoIyi/Rp+ddq7j4wG6JDskQ==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.2.0.tgz",
|
||||
"integrity": "sha512-VnpoMaGukiNWVxeqKHwi8MN47yKGyki5q+7ql/7p/3ifuU2341i/gDwGK1rivk0pVYbdv5D8z63uu9yMw0QhpQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
@ -2690,12 +2694,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.0.2.tgz",
|
||||
"integrity": "sha512-BjH4ZSzJIoTTZRh6rG+a/Ry4SW0HlizcPorqNBixBWc3wtQtj4Sn9FnRZe22QqrPnzoaW0ctvSz4FaH4eGKMww==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.2.0.tgz",
|
||||
"integrity": "sha512-ka97/ssYE5nPH4Qs+8bd8RlYeNeUVBhcnsNUmFM6VWEob4jfN9FTr0NBhXVi1XEJpj3cMfgSRW+LdE3SUZbPrw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
@ -2705,12 +2710,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.0.2.tgz",
|
||||
"integrity": "sha512-i3U2TcHgo26sIhcwX/Rshz6avM6nizrZPvrDVDY1bXcLH1ndjbO8zuC7RoHp0NSK7wjJMPYzm7NYL1ksSKFreA==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.2.0.tgz",
|
||||
"integrity": "sha512-zY1JduE4B3q0k2ZCE+DAF/1efjTXUsKP+VXRtrt/rJCTgDlUyyryx7aOgYXNc1d8gobys/Lof9P9ze8IyRDn7Q==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
@ -2720,12 +2726,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.0.2.tgz",
|
||||
"integrity": "sha512-AMfZfSVOIR8fa+TXlAooByEF4OB00wqnms1sJ1v+iu8ivwvtPvnkwdzzFMpsK5jA2S9oNeeQ04egIWVb4QWmtQ==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.2.0.tgz",
|
||||
"integrity": "sha512-QqvLZpurBD46RhaVaVBepkVQzh8xtlUN00RlG4Iq1sBheNugamUNPuZEH1r9X1YGQo1KqAe1iiShF0acva3jHQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"linux"
|
||||
@ -2735,12 +2742,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.0.2.tgz",
|
||||
"integrity": "sha512-JkXysDT0/hEY47O+Hvs8PbZAeiCQVxKfGtr4GUpNAhlG2E0Mkjibuo8ryGD29Qb5a3IOnKYNoZlh/MyKd2Nbww==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.2.0.tgz",
|
||||
"integrity": "sha512-ODZ0r9WMyylTHAN6pLtvUtQlGXBL9voljv6ujSlcsjOxhtXPI1Ag6AhZK0SE8hEpR1374WZZ5w33ChpJd5fsjw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
@ -2750,12 +2758,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.0.2.tgz",
|
||||
"integrity": "sha512-foaUL0NqJY/dX0Pi/UcZm5zsmSk5MtP/gxx3xOPyREkMFN+CTjctPfu3QaqrQHinaKdPnMWPJDKt4VjDfTBe/Q==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.2.0.tgz",
|
||||
"integrity": "sha512-8+4Z3Z7xa13NdUuUAcpVNA6o76lNPniBd9Xbo02bwXQXnZgFvEopwY2at5+z7yHl47X9qbZpvwatZ2BRo3EdZw==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
"license": "MIT",
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
@ -7389,11 +7398,12 @@
|
||||
"integrity": "sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ=="
|
||||
},
|
||||
"node_modules/@swc/helpers": {
|
||||
"version": "0.5.13",
|
||||
"resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.13.tgz",
|
||||
"integrity": "sha512-UoKGxQ3r5kYI9dALKJapMmuK+1zWM/H17Z1+iwnNmzcJRnfFuevZs375TA5rW31pu4BS4NoSy1fRsexDXfWn5w==",
|
||||
"version": "0.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.15.tgz",
|
||||
"integrity": "sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"tslib": "^2.4.0"
|
||||
"tslib": "^2.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tailwindcss/typography": {
|
||||
@ -14966,13 +14976,14 @@
|
||||
"integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw=="
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "15.0.2",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-15.0.2.tgz",
|
||||
"integrity": "sha512-rxIWHcAu4gGSDmwsELXacqAPUk+j8dV/A9cDF5fsiCMpkBDYkO2AEaL1dfD+nNmDiU6QMCFN8Q30VEKapT9UHQ==",
|
||||
"version": "15.2.0",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-15.2.0.tgz",
|
||||
"integrity": "sha512-VaiM7sZYX8KIAHBrRGSFytKknkrexNfGb8GlG6e93JqueCspuGte8i4ybn8z4ww1x3f2uzY4YpTaBEW4/hvsoQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "15.0.2",
|
||||
"@next/env": "15.2.0",
|
||||
"@swc/counter": "0.1.3",
|
||||
"@swc/helpers": "0.5.13",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"busboy": "1.6.0",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
"postcss": "8.4.31",
|
||||
@ -14982,25 +14993,25 @@
|
||||
"next": "dist/bin/next"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.18.0"
|
||||
"node": "^18.18.0 || ^19.8.0 || >= 20.0.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "15.0.2",
|
||||
"@next/swc-darwin-x64": "15.0.2",
|
||||
"@next/swc-linux-arm64-gnu": "15.0.2",
|
||||
"@next/swc-linux-arm64-musl": "15.0.2",
|
||||
"@next/swc-linux-x64-gnu": "15.0.2",
|
||||
"@next/swc-linux-x64-musl": "15.0.2",
|
||||
"@next/swc-win32-arm64-msvc": "15.0.2",
|
||||
"@next/swc-win32-x64-msvc": "15.0.2",
|
||||
"@next/swc-darwin-arm64": "15.2.0",
|
||||
"@next/swc-darwin-x64": "15.2.0",
|
||||
"@next/swc-linux-arm64-gnu": "15.2.0",
|
||||
"@next/swc-linux-arm64-musl": "15.2.0",
|
||||
"@next/swc-linux-x64-gnu": "15.2.0",
|
||||
"@next/swc-linux-x64-musl": "15.2.0",
|
||||
"@next/swc-win32-arm64-msvc": "15.2.0",
|
||||
"@next/swc-win32-x64-msvc": "15.2.0",
|
||||
"sharp": "^0.33.5"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@opentelemetry/api": "^1.1.0",
|
||||
"@playwright/test": "^1.41.2",
|
||||
"babel-plugin-react-compiler": "*",
|
||||
"react": "^18.2.0 || 19.0.0-rc-02c0e824-20241028",
|
||||
"react-dom": "^18.2.0 || 19.0.0-rc-02c0e824-20241028",
|
||||
"react": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0",
|
||||
"react-dom": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0",
|
||||
"sass": "^1.3.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
@ -20590,9 +20601,10 @@
|
||||
}
|
||||
},
|
||||
"node_modules/tslib": {
|
||||
"version": "2.6.2",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
|
||||
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
|
||||
"version": "2.8.1",
|
||||
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz",
|
||||
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
|
||||
"license": "0BSD"
|
||||
},
|
||||
"node_modules/type-check": {
|
||||
"version": "0.4.0",
|
||||
|
@ -55,7 +55,7 @@
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.454.0",
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"next": "^15.0.2",
|
||||
"next": "^15.2.0",
|
||||
"next-themes": "^0.4.4",
|
||||
"npm": "^10.8.0",
|
||||
"postcss": "^8.4.31",
|
||||
|
@ -100,11 +100,6 @@ export default function LabelManagement() {
|
||||
width="w-full max-w-xs"
|
||||
name={`editLabelName_${label.id}`}
|
||||
label="Label Name"
|
||||
value={
|
||||
values.editLabelId === label.id
|
||||
? values.editLabelName
|
||||
: label.name
|
||||
}
|
||||
onChange={(e) => {
|
||||
setFieldValue("editLabelId", label.id);
|
||||
setFieldValue("editLabelName", e.target.value);
|
||||
|
@ -163,7 +163,7 @@ export function PersonasTable() {
|
||||
{popup}
|
||||
{deleteModalOpen && personaToDelete && (
|
||||
<ConfirmEntityModal
|
||||
entityType="Persona"
|
||||
entityType="Assistant"
|
||||
entityName={personaToDelete.name}
|
||||
onClose={closeDeleteModal}
|
||||
onSubmit={handleDeletePersona}
|
||||
|
@ -52,7 +52,6 @@ export default function StarterMessagesList({
|
||||
<TextFormField
|
||||
name={`starter_messages.${index}.message`}
|
||||
label=""
|
||||
value={starterMessage.message}
|
||||
onChange={(e) => handleInputChange(index, e.target.value)}
|
||||
className="flex-grow"
|
||||
removeLabel
|
||||
|
@ -83,6 +83,8 @@ export const SlackChannelConfigCreationForm = ({
|
||||
respond_tag_only:
|
||||
existingSlackChannelConfig?.channel_config?.respond_tag_only ||
|
||||
false,
|
||||
is_ephemeral:
|
||||
existingSlackChannelConfig?.channel_config?.is_ephemeral || false,
|
||||
respond_to_bots:
|
||||
existingSlackChannelConfig?.channel_config?.respond_to_bots ||
|
||||
false,
|
||||
@ -135,6 +137,7 @@ export const SlackChannelConfigCreationForm = ({
|
||||
questionmark_prefilter_enabled: Yup.boolean().required(),
|
||||
respond_tag_only: Yup.boolean().required(),
|
||||
respond_to_bots: Yup.boolean().required(),
|
||||
is_ephemeral: Yup.boolean().required(),
|
||||
show_continue_in_web_ui: Yup.boolean().required(),
|
||||
enable_auto_filters: Yup.boolean().required(),
|
||||
respond_member_group_list: Yup.array().of(Yup.string()).required(),
|
||||
|
@ -199,17 +199,17 @@ export function SlackChannelConfigFormFields({
|
||||
<Badge variant="agent" className="bg-blue-100 text-blue-800">
|
||||
Default Configuration
|
||||
</Badge>
|
||||
<p className="mt-2 text-sm text-gray-600">
|
||||
<p className="mt-2 text-sm text-neutral-600">
|
||||
This default configuration will apply across all Slack channels
|
||||
the bot is added to in the Slack workspace, as well as direct
|
||||
messages (DMs), unless disabled.
|
||||
</p>
|
||||
<div className="mt-4 p-4 bg-gray-100 rounded-md border border-gray-300">
|
||||
<div className="mt-4 p-4 bg-neutral-100 rounded-md border border-neutral-300">
|
||||
<CheckFormField
|
||||
name="disabled"
|
||||
label="Disable Default Configuration"
|
||||
/>
|
||||
<p className="mt-2 text-sm text-gray-600 italic">
|
||||
<p className="mt-2 text-sm text-neutral-600 italic">
|
||||
Warning: Disabling the default configuration means the bot
|
||||
won't respond in Slack channels or DMs unless explicitly
|
||||
configured for them.
|
||||
@ -238,20 +238,28 @@ export function SlackChannelConfigFormFields({
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<Field name="channel_name">
|
||||
{({ field, form }: { field: any; form: any }) => (
|
||||
<SearchMultiSelectDropdown
|
||||
options={channelOptions || []}
|
||||
onSelect={(selected) => {
|
||||
form.setFieldValue("channel_name", selected.name);
|
||||
}}
|
||||
initialSearchTerm={field.value}
|
||||
onSearchTermChange={(term) => {
|
||||
form.setFieldValue("channel_name", term);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<>
|
||||
<Field name="channel_name">
|
||||
{({ field, form }: { field: any; form: any }) => (
|
||||
<SearchMultiSelectDropdown
|
||||
options={channelOptions || []}
|
||||
onSelect={(selected) => {
|
||||
form.setFieldValue("channel_name", selected.name);
|
||||
}}
|
||||
initialSearchTerm={field.value}
|
||||
onSearchTermChange={(term) => {
|
||||
form.setFieldValue("channel_name", term);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Field>
|
||||
<p className="mt-2 text-sm dark:text-neutral-400 text-neutral-600">
|
||||
Note: This list shows public and private channels where the
|
||||
bot is a member (up to 500 channels). If you don't see a
|
||||
channel, make sure the bot is added to that channel in Slack
|
||||
first, or type the channel name manually.
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
@ -589,6 +597,13 @@ export function SlackChannelConfigFormFields({
|
||||
label="Respond to Bot messages"
|
||||
tooltip="If not set, OnyxBot will always ignore messages from Bots"
|
||||
/>
|
||||
<CheckFormField
|
||||
name="is_ephemeral"
|
||||
label="Respond to user in a private (ephemeral) message"
|
||||
tooltip="If set, OnyxBot will respond only to the user in a private (ephemeral) message. If you also
|
||||
chose 'Search' Assistant above, selecting this option will make documents that are private to the user
|
||||
available for their queries."
|
||||
/>
|
||||
|
||||
<TextArrayField
|
||||
name="respond_member_group_list"
|
||||
@ -627,11 +642,14 @@ export function SlackChannelConfigFormFields({
|
||||
Privacy Alert
|
||||
</Label>
|
||||
<p className="text-sm text-text-darker mb-4">
|
||||
Please note that at least one of the documents accessible by
|
||||
your OnyxBot is marked as private and may contain sensitive
|
||||
information. These documents will be accessible to all users
|
||||
of this OnyxBot. Ensure this aligns with your intended
|
||||
document sharing policy.
|
||||
Please note that if the private (ephemeral) response is *not
|
||||
selected*, only public documents within the selected document
|
||||
sets will be accessible for user queries. If the private
|
||||
(ephemeral) response *is selected*, user quries can also
|
||||
leverage documents that the user has already been granted
|
||||
access to. Note that users will be able to share the response
|
||||
with others in the channel, so please ensure that this is
|
||||
aligned with your company sharing policies.
|
||||
</p>
|
||||
<div className="space-y-2">
|
||||
<h4 className="text-sm text-text font-medium">
|
||||
|
@ -14,6 +14,7 @@ interface SlackChannelConfigCreationRequest {
|
||||
answer_validity_check_enabled: boolean;
|
||||
questionmark_prefilter_enabled: boolean;
|
||||
respond_tag_only: boolean;
|
||||
is_ephemeral: boolean;
|
||||
respond_to_bots: boolean;
|
||||
show_continue_in_web_ui: boolean;
|
||||
respond_member_group_list: string[];
|
||||
@ -45,6 +46,7 @@ const buildRequestBodyFromCreationRequest = (
|
||||
channel_name: creationRequest.channel_name,
|
||||
respond_tag_only: creationRequest.respond_tag_only,
|
||||
respond_to_bots: creationRequest.respond_to_bots,
|
||||
is_ephemeral: creationRequest.is_ephemeral,
|
||||
show_continue_in_web_ui: creationRequest.show_continue_in_web_ui,
|
||||
enable_auto_filters: creationRequest.enable_auto_filters,
|
||||
respond_member_group_list: creationRequest.respond_member_group_list,
|
||||
|
@ -71,7 +71,7 @@ function Main() {
|
||||
<p className="text-text-600">
|
||||
Learn more about Unstructured{" "}
|
||||
<a
|
||||
href="https://unstructured.io/docs"
|
||||
href="https://docs.unstructured.io/welcome"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-500 hover:underline font-medium"
|
||||
|
@ -108,15 +108,13 @@ export default function UpgradingPage({
|
||||
>
|
||||
<div>
|
||||
<div>
|
||||
Are you sure you want to cancel?
|
||||
<br />
|
||||
<br />
|
||||
Cancelling will revert to the previous model and all progress will
|
||||
be lost.
|
||||
Are you sure you want to cancel? Cancelling will revert to the
|
||||
previous model and all progress will be lost.
|
||||
</div>
|
||||
<div className="flex">
|
||||
<Button onClick={onCancel} variant="submit">
|
||||
Confirm
|
||||
<div className="mt-12 gap-x-2 w-full justify-end flex">
|
||||
<Button onClick={onCancel}>Confirm</Button>
|
||||
<Button onClick={() => setIsCancelling(false)} variant="outline">
|
||||
Cancel
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
@ -141,30 +139,46 @@ export default function UpgradingPage({
|
||||
</Button>
|
||||
|
||||
{connectors && connectors.length > 0 ? (
|
||||
<>
|
||||
{failedIndexingStatus && failedIndexingStatus.length > 0 && (
|
||||
<FailedReIndexAttempts
|
||||
failedIndexingStatuses={failedIndexingStatus}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
)}
|
||||
futureEmbeddingModel.background_reindex_enabled ? (
|
||||
<>
|
||||
{failedIndexingStatus && failedIndexingStatus.length > 0 && (
|
||||
<FailedReIndexAttempts
|
||||
failedIndexingStatuses={failedIndexingStatus}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Text className="my-4">
|
||||
The table below shows the re-indexing progress of all existing
|
||||
connectors. Once all connectors have been re-indexed
|
||||
successfully, the new model will be used for all search
|
||||
queries. Until then, we will use the old model so that no
|
||||
downtime is necessary during this transition.
|
||||
</Text>
|
||||
<Text className="my-4">
|
||||
The table below shows the re-indexing progress of all
|
||||
existing connectors. Once all connectors have been
|
||||
re-indexed successfully, the new model will be used for all
|
||||
search queries. Until then, we will use the old model so
|
||||
that no downtime is necessary during this transition.
|
||||
</Text>
|
||||
|
||||
{sortedReindexingProgress ? (
|
||||
<ReindexingProgressTable
|
||||
reindexingProgress={sortedReindexingProgress}
|
||||
/>
|
||||
) : (
|
||||
<ErrorCallout errorTitle="Failed to fetch reindexing progress" />
|
||||
)}
|
||||
</>
|
||||
{sortedReindexingProgress ? (
|
||||
<ReindexingProgressTable
|
||||
reindexingProgress={sortedReindexingProgress}
|
||||
/>
|
||||
) : (
|
||||
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
<div className="mt-8">
|
||||
<h3 className="text-lg font-semibold mb-2">
|
||||
Switching Embedding Models
|
||||
</h3>
|
||||
<p className="mb-4 text-text-800">
|
||||
You're currently switching embedding models, and
|
||||
you've selected the instant switch option. The
|
||||
transition will complete shortly.
|
||||
</p>
|
||||
<p className="text-text-600">
|
||||
The new model will be active soon.
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
) : (
|
||||
<div className="mt-8 p-6 bg-background-100 border border-border-strong rounded-lg max-w-2xl">
|
||||
<h3 className="text-lg font-semibold mb-2">
|
||||
|
@ -455,15 +455,15 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
<Title>Indexing Attempts</Title>
|
||||
</div>
|
||||
{indexAttemptErrors && indexAttemptErrors.total_items > 0 && (
|
||||
<Alert className="border-alert bg-yellow-50 my-2">
|
||||
<AlertCircle className="h-4 w-4 text-yellow-700" />
|
||||
<AlertTitle className="text-yellow-950 font-semibold">
|
||||
<Alert className="border-alert bg-yellow-50 dark:bg-yellow-800 my-2">
|
||||
<AlertCircle className="h-4 w-4 text-yellow-700 dark:text-yellow-500" />
|
||||
<AlertTitle className="text-yellow-950 dark:text-yellow-200 font-semibold">
|
||||
Some documents failed to index
|
||||
</AlertTitle>
|
||||
<AlertDescription className="text-yellow-900">
|
||||
<AlertDescription className="text-yellow-900 dark:text-yellow-300">
|
||||
{isResolvingErrors ? (
|
||||
<span>
|
||||
<span className="text-sm text-yellow-700 animate-pulse">
|
||||
<span className="text-sm text-yellow-700 dark:text-yellow-400 da animate-pulse">
|
||||
Resolving failures
|
||||
</span>
|
||||
</span>
|
||||
@ -471,7 +471,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
<>
|
||||
We ran into some issues while processing some documents.{" "}
|
||||
<b
|
||||
className="text-link cursor-pointer"
|
||||
className="text-link cursor-pointer dark:text-blue-300"
|
||||
onClick={() => setShowIndexAttemptErrors(true)}
|
||||
>
|
||||
View details.
|
||||
|
@ -193,13 +193,15 @@ export default function AddConnector({
|
||||
// Check if there are no credentials
|
||||
const noCredentials = credentialTemplate == null;
|
||||
|
||||
if (noCredentials && 1 != formStep) {
|
||||
setFormStep(Math.max(1, formStep));
|
||||
}
|
||||
useEffect(() => {
|
||||
if (noCredentials && 1 != formStep) {
|
||||
setFormStep(Math.max(1, formStep));
|
||||
}
|
||||
|
||||
if (!noCredentials && !credentialActivated && formStep != 0) {
|
||||
setFormStep(Math.min(formStep, 0));
|
||||
}
|
||||
if (!noCredentials && !credentialActivated && formStep != 0) {
|
||||
setFormStep(Math.min(formStep, 0));
|
||||
}
|
||||
}, [noCredentials, formStep, setFormStep]);
|
||||
|
||||
const convertStringToDateTime = (indexingStart: string | null) => {
|
||||
return indexingStart ? new Date(indexingStart) : null;
|
||||
|
@ -33,7 +33,7 @@ export default function OAuthCallbackPage() {
|
||||
const connector = pathname?.split("/")[3];
|
||||
|
||||
useEffect(() => {
|
||||
const handleOAuthCallback = async () => {
|
||||
const onFirstLoad = async () => {
|
||||
// Examples
|
||||
// connector (url segment)= "google-drive"
|
||||
// sourceType (for looking up metadata) = "google_drive"
|
||||
@ -85,10 +85,19 @@ export default function OAuthCallbackPage() {
|
||||
}
|
||||
|
||||
setStatusMessage("Success!");
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
|
||||
);
|
||||
setRedirectUrl(response.redirect_on_success); // Extract the redirect URL
|
||||
|
||||
// set the continuation link
|
||||
if (response.finalize_url) {
|
||||
setRedirectUrl(response.finalize_url);
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully. Additional steps are required to complete credential setup.`
|
||||
);
|
||||
} else {
|
||||
setRedirectUrl(response.redirect_on_success);
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
|
||||
);
|
||||
}
|
||||
setIsError(false);
|
||||
} catch (error) {
|
||||
console.error("OAuth error:", error);
|
||||
@ -100,15 +109,15 @@ export default function OAuthCallbackPage() {
|
||||
}
|
||||
};
|
||||
|
||||
handleOAuthCallback();
|
||||
onFirstLoad();
|
||||
}, [code, state, connector]);
|
||||
|
||||
return (
|
||||
<div className="container mx-auto py-8">
|
||||
<div className="mx-auto h-screen flex flex-col">
|
||||
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
|
||||
|
||||
<div className="flex flex-col items-center justify-center min-h-screen">
|
||||
<CardSection className="max-w-md">
|
||||
<div className="flex-1 flex flex-col items-center justify-center">
|
||||
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
|
||||
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
|
||||
<p className="text-text-500">{statusDetails}</p>
|
||||
{redirectUrl && !isError && (
|
||||
|
293
web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx
Normal file
293
web/src/app/admin/connectors/[connector]/oauth/finalize/page.tsx
Normal file
@ -0,0 +1,293 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Title from "@/components/ui/title";
|
||||
import { KeyIcon } from "@/components/icons/icons";
|
||||
import { getSourceMetadata, isValidSource } from "@/lib/sources";
|
||||
import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import {
|
||||
handleOAuthAuthorizationResponse,
|
||||
handleOAuthConfluenceFinalize,
|
||||
handleOAuthPrepareFinalization,
|
||||
} from "@/lib/oauth_utils";
|
||||
import { SelectorFormField } from "@/components/admin/connectors/Field";
|
||||
import { ErrorMessage, Field, Form, Formik, useFormikContext } from "formik";
|
||||
import * as Yup from "yup";
|
||||
|
||||
// Helper component to keep the effect logic clean:
|
||||
function UpdateCloudURLOnCloudIdChange({
|
||||
accessibleResources,
|
||||
}: {
|
||||
accessibleResources: ConfluenceAccessibleResource[];
|
||||
}) {
|
||||
const { values, setValues, setFieldValue } = useFormikContext<{
|
||||
cloud_id: string;
|
||||
cloud_name: string;
|
||||
cloud_url: string;
|
||||
}>();
|
||||
|
||||
useEffect(() => {
|
||||
// Whenever cloud_id changes, find the matching resource and update cloud_url
|
||||
if (values.cloud_id) {
|
||||
const selectedResource = accessibleResources.find(
|
||||
(resource) => resource.id === values.cloud_id
|
||||
);
|
||||
if (selectedResource) {
|
||||
// Update multiple fields together ... somehow setting them in sequence
|
||||
// doesn't work with the validator
|
||||
// it may also be possible to await each setFieldValue call.
|
||||
// https://github.com/jaredpalmer/formik/issues/2266
|
||||
setValues((prevValues) => ({
|
||||
...prevValues,
|
||||
cloud_name: selectedResource.name,
|
||||
cloud_url: selectedResource.url,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}, [values.cloud_id, accessibleResources, setFieldValue]);
|
||||
|
||||
// This component doesn't render anything visible:
|
||||
return null;
|
||||
}
|
||||
|
||||
export default function OAuthFinalizePage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
const [statusMessage, setStatusMessage] = useState("Processing...");
|
||||
const [statusDetails, setStatusDetails] = useState(
|
||||
"Please wait while we complete the setup."
|
||||
);
|
||||
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
|
||||
const [isError, setIsError] = useState(false);
|
||||
const [isSubmitted, setIsSubmitted] = useState(false); // New state
|
||||
const [pageTitle, setPageTitle] = useState(
|
||||
"Finalize Authorization with Third-Party service"
|
||||
);
|
||||
|
||||
const [accessibleResources, setAccessibleResources] = useState<
|
||||
ConfluenceAccessibleResource[]
|
||||
>([]);
|
||||
|
||||
// Extract query parameters
|
||||
const credentialParam = searchParams.get("credential");
|
||||
const credential = credentialParam ? parseInt(credentialParam, 10) : NaN;
|
||||
const pathname = usePathname();
|
||||
const connector = pathname?.split("/")[3];
|
||||
|
||||
useEffect(() => {
|
||||
const onFirstLoad = async () => {
|
||||
// Examples
|
||||
// connector (url segment)= "google-drive"
|
||||
// sourceType (for looking up metadata) = "google_drive"
|
||||
|
||||
if (isNaN(credential)) {
|
||||
setStatusMessage("Improperly formed OAuth finalization request.");
|
||||
setStatusDetails("Invalid or missing credential id.");
|
||||
setIsError(true);
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceType = connector.replaceAll("-", "_");
|
||||
if (!isValidSource(sourceType)) {
|
||||
setStatusMessage(
|
||||
`The specified connector source type ${sourceType} does not exist.`
|
||||
);
|
||||
setStatusDetails(`${sourceType} is not a valid source type.`);
|
||||
setIsError(true);
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceMetadata = getSourceMetadata(sourceType as ValidSources);
|
||||
setPageTitle(`Finalize Authorization with ${sourceMetadata.displayName}`);
|
||||
|
||||
setStatusMessage("Processing...");
|
||||
setStatusDetails(
|
||||
"Please wait while we retrieve a list of your accessible sites."
|
||||
);
|
||||
setIsError(false); // Ensure no error state during loading
|
||||
|
||||
try {
|
||||
const response = await handleOAuthPrepareFinalization(
|
||||
connector,
|
||||
credential
|
||||
);
|
||||
|
||||
if (!response) {
|
||||
throw new Error("Empty response from OAuth server.");
|
||||
}
|
||||
|
||||
setAccessibleResources(response.accessible_resources);
|
||||
|
||||
setStatusMessage("Select a Confluence site");
|
||||
setStatusDetails("");
|
||||
|
||||
setIsError(false);
|
||||
} catch (error) {
|
||||
console.error("OAuth finalization error:", error);
|
||||
setStatusMessage("Oops, something went wrong!");
|
||||
setStatusDetails(
|
||||
"An error occurred during the OAuth finalization process. Please try again."
|
||||
);
|
||||
setIsError(true);
|
||||
}
|
||||
};
|
||||
|
||||
onFirstLoad();
|
||||
}, [credential, connector]);
|
||||
|
||||
useEffect(() => {}, [redirectUrl]);
|
||||
|
||||
return (
|
||||
<div className="mx-auto h-screen flex flex-col">
|
||||
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
|
||||
|
||||
<div className="flex-1 flex flex-col items-center justify-center">
|
||||
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
|
||||
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
|
||||
<p className="text-text-500">{statusDetails}</p>
|
||||
|
||||
<Formik
|
||||
initialValues={{
|
||||
credential_id: credential,
|
||||
cloud_id: "",
|
||||
cloud_name: "",
|
||||
cloud_url: "",
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
credential_id: Yup.number().required(
|
||||
"Credential ID is required."
|
||||
),
|
||||
cloud_id: Yup.string().required(
|
||||
"You must select a Confluence site (id not found)."
|
||||
),
|
||||
cloud_name: Yup.string().required(
|
||||
"You must select a Confluence site (name not found)."
|
||||
),
|
||||
cloud_url: Yup.string().required(
|
||||
"You must select a Confluence site (url not found)."
|
||||
),
|
||||
})}
|
||||
validateOnMount
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
try {
|
||||
if (!values.cloud_id) {
|
||||
throw new Error("Cloud ID is required.");
|
||||
}
|
||||
|
||||
if (!values.cloud_name) {
|
||||
throw new Error("Cloud URL is required.");
|
||||
}
|
||||
|
||||
if (!values.cloud_url) {
|
||||
throw new Error("Cloud URL is required.");
|
||||
}
|
||||
|
||||
const response = await handleOAuthConfluenceFinalize(
|
||||
values.credential_id,
|
||||
values.cloud_id,
|
||||
values.cloud_name,
|
||||
values.cloud_url
|
||||
);
|
||||
formikHelpers.setSubmitting(false);
|
||||
|
||||
if (response) {
|
||||
setRedirectUrl(response.redirect_url);
|
||||
setStatusMessage("Confluence authorization finalized.");
|
||||
}
|
||||
|
||||
setIsSubmitted(true); // Mark as submitted
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
setStatusMessage("Error during submission.");
|
||||
setStatusDetails(
|
||||
"An error occurred during the submission process. Please try again."
|
||||
);
|
||||
setIsError(true);
|
||||
formikHelpers.setSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, isValid, setFieldValue }) => (
|
||||
<Form>
|
||||
{/* Debug info
|
||||
<div className="mb-4 p-2 bg-gray-100 rounded text-xs">
|
||||
<pre>
|
||||
isValid: {String(isValid)}
|
||||
errors: {JSON.stringify(errors, null, 2)}
|
||||
values: {JSON.stringify(values, null, 2)}
|
||||
</pre>
|
||||
</div> */}
|
||||
|
||||
{/* Our helper component that reacts to changes in cloud_id */}
|
||||
<UpdateCloudURLOnCloudIdChange
|
||||
accessibleResources={accessibleResources}
|
||||
/>
|
||||
|
||||
<Field type="hidden" name="cloud_name" />
|
||||
<ErrorMessage
|
||||
name="cloud_name"
|
||||
component="div"
|
||||
className="error"
|
||||
/>
|
||||
|
||||
<Field type="hidden" name="cloud_url" />
|
||||
<ErrorMessage
|
||||
name="cloud_url"
|
||||
component="div"
|
||||
className="error"
|
||||
/>
|
||||
|
||||
{!redirectUrl && accessibleResources.length > 0 && (
|
||||
<SelectorFormField
|
||||
name="cloud_id"
|
||||
options={accessibleResources.map((resource) => ({
|
||||
name: `${resource.name} - ${resource.url}`,
|
||||
value: resource.id,
|
||||
}))}
|
||||
onSelect={(selectedValue) => {
|
||||
const selectedResource = accessibleResources.find(
|
||||
(resource) => resource.id === selectedValue
|
||||
);
|
||||
if (selectedResource) {
|
||||
setFieldValue("cloud_id", selectedResource.id);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<br />
|
||||
{!redirectUrl && (
|
||||
<Button
|
||||
type="submit"
|
||||
size="sm"
|
||||
variant="submit"
|
||||
disabled={!isValid || isSubmitting}
|
||||
>
|
||||
{isSubmitting ? "Submitting..." : "Submit"}
|
||||
</Button>
|
||||
)}
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
{redirectUrl && !isError && (
|
||||
<div className="mt-4">
|
||||
<p className="text-sm">
|
||||
Authorization finalized. Click{" "}
|
||||
<a href={redirectUrl} className="text-blue-500 underline">
|
||||
here
|
||||
</a>{" "}
|
||||
to continue.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</CardSection>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user