Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/schema-translate-map

# Conflicts:
#	backend/ee/onyx/server/oauth.py
#	backend/onyx/background/celery/tasks/indexing/tasks.py
#	backend/onyx/db/search_settings.py
#	backend/onyx/key_value_store/store.py
#	backend/onyx/onyxbot/slack/handlers/handle_buttons.py
#	backend/tests/integration/common_utils/reset.py
Richard Kuo (Danswer) 2025-03-03 15:08:17 -08:00
commit 7acbadd825
145 changed files with 5359 additions and 1394 deletions

.github/CODEOWNERS

@@ -0,0 +1 @@
* @onyx-dot-app/onyx-core-team


@@ -53,24 +53,90 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: ${{ always() }}
if: always()
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge number of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@0.29.0
# with:
# sarif_file: trivy-results.sarif
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0


@@ -0,0 +1,55 @@
"""add background_reindex_enabled field
Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import EmbeddingPrecision
# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add background_reindex_enabled column with default value of True
op.add_column(
"search_settings",
sa.Column(
"background_reindex_enabled",
sa.Boolean(),
nullable=False,
server_default="true",
),
)
# Add embedding_precision column with default value of FLOAT
op.add_column(
"search_settings",
sa.Column(
"embedding_precision",
sa.Enum(EmbeddingPrecision, native_enum=False),
nullable=False,
server_default=EmbeddingPrecision.FLOAT.name,
),
)
# Add reduced_dimension column with default value of None
op.add_column(
"search_settings",
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
)
def downgrade() -> None:
# Remove the background_reindex_enabled column
op.drop_column("search_settings", "background_reindex_enabled")
op.drop_column("search_settings", "embedding_precision")
op.drop_column("search_settings", "reduced_dimension")
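
A note on the enum column above: with native_enum=False, SQLAlchemy declares a plain VARCHAR instead of a native Postgres ENUM type, and for a PEP 435 enum it persists the member names by default, which is why server_default=EmbeddingPrecision.FLOAT.name lines up with the stored values. A minimal sketch of what this compiles to (the real EmbeddingPrecision lives in onyx/db/enums.py; the second member here is hypothetical):

import enum

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateTable


class EmbeddingPrecision(enum.Enum):
    FLOAT = "float"
    BFLOAT16 = "bfloat16"  # hypothetical member, for illustration only


metadata = sa.MetaData()
demo = sa.Table(
    "search_settings_demo",  # demo table name, not the real one
    metadata,
    sa.Column(
        "embedding_precision",
        sa.Enum(EmbeddingPrecision, native_enum=False),
        nullable=False,
        server_default=EmbeddingPrecision.FLOAT.name,
    ),
)

# prints a VARCHAR column (sized to the longest member name) with a
# server default of 'FLOAT' -- no Postgres ENUM type is created
print(CreateTable(demo).compile(dialect=postgresql.dialect()))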


@@ -0,0 +1,36 @@
"""force lowercase all users
Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500
"""
# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing user emails to lowercase
from alembic import op
op.execute(
"""
UPDATE "user"
SET email = LOWER(email)
"""
)
# 2) Add a check constraint to ensure emails are always lowercase
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
def downgrade() -> None:
# Drop the check constraint
from alembic import op
op.drop_constraint("ensure_lowercase_email", "user", type_="check")
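
For intuition, the check constraint added above moves the lowercase invariant into Postgres itself: after this migration, any write that stores a mixed-case email fails at the database level, not just in application code. A hedged illustration (assumes an existing SQLAlchemy engine and at least one row whose email contains letters):

from sqlalchemy import text
from sqlalchemy.exc import IntegrityError

try:
    with engine.begin() as conn:  # 'engine' is an assumed SQLAlchemy engine
        # violates ensure_lowercase_email on any existing row
        conn.execute(text('UPDATE "user" SET email = UPPER(email)'))
except IntegrityError:
    print("rejected: user emails must stay lowercase")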


@@ -0,0 +1,42 @@
"""lowercase multi-tenant user auth
Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing rows to lowercase
op.execute(
"""
UPDATE user_tenant_mapping
SET email = LOWER(email)
"""
)
# 2) Add a check constraint so that emails cannot be written in uppercase
op.create_check_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
"email = LOWER(email)",
schema="public",
)
def downgrade() -> None:
# Drop the check constraint
op.drop_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
schema="public",
type_="check",
)


@@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""


@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC
# otherwise, set their role to BASIC only if they were previously a CURATOR
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -631,7 +631,16 @@ def update_user_group(
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# Filter out admin and global curator users before validating curator status
users_to_validate = [
user
for user in removed_users
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
]
if users_to_validate:
_validate_curator_status__no_commit(db_session, users_to_validate)
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()


@@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -354,7 +359,11 @@
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "confluence", cc_pair.credential_id
)
confluence_connector.set_credentials_provider(provider)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)


@@ -1,9 +1,11 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -61,13 +63,27 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
url = wiki_base.rstrip("/")
probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
confluence_client = OnyxConfluence(is_cloud, url, provider)
confluence_client._probe_connection(**probe_kwargs)
confluence_client._initialize_connection(**final_kwargs)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,


@@ -32,7 +32,8 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres


@@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres


@@ -119,6 +119,7 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects


@@ -123,7 +123,8 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres


@@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],


@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth import router as oauth_router
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -152,4 +152,8 @@ def get_application() -> FastAPI:
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")
return application


@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger
@@ -216,7 +216,7 @@ def _handle_standard_answers(
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
@@ -231,6 +231,7 @@
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
receiver_ids=receiver_ids,
)
return True


@@ -0,0 +1,91 @@
import base64
import uuid
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from ee.onyx.server.oauth.api_router import router
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str | None = None
if connector == DocumentSource.SLACK:
if not DEV_MODE:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.CONFLUENCE:
if not DEV_MODE:
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
session = ConfluenceCloudOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
if not DEV_MODE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
if not session:
raise HTTPException(
status_code=500,
detail=f"The document source type {connector} failed to generate an OAuth session.",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
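
The state parameter above round-trips through urlsafe base64 with its '=' padding stripped (padding characters are awkward in query strings); the callback handlers below restore the padding before decoding. A self-contained sketch of that round trip, using only the standard library:

import base64
import uuid

oauth_uuid = uuid.uuid4()

# encode: urlsafe base64 of the 16 raw UUID bytes, minus '=' padding
state = base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")

# decode: pad back out to a multiple of 4, then recover the UUID
padded = state + "=" * (-len(state) % 4)
assert uuid.UUID(bytes=base64.urlsafe_b64decode(padded)) == oauth_uuid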


@@ -0,0 +1,3 @@
from fastapi import APIRouter
router: APIRouter = APIRouter(prefix="/oauth")


@@ -0,0 +1,361 @@
import base64
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ConfluenceCloudOAuth:
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class AccessibleResources(BaseModel):
id: str
name: str
url: str
scopes: list[str]
avatarUrl: str
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
ACCESSIBLE_RESOURCE_URL = (
"https://api.atlassian.com/oauth/token/accessible-resources"
)
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
# classic scope
"read:confluence-space.summary%20"
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence%20"
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"offline_access"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&redirect_uri={redirect_uri}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
@classmethod
def generate_finalize_url(cls, credential_id: int) -> str:
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Handles the backend logic for the frontend page that the user is redirected to
after visiting the oauth authorization url."""
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Confluence Cloud client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = ConfluenceCloudOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
else:
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
ConfluenceCloudOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
try:
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
response.text
)
except Exception:
raise RuntimeError(
"Confluence Cloud OAuth failed during code/token exchange."
)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
credential_info = CredentialBase(
credential_json={
"confluence_access_token": token_response.access_token,
"confluence_refresh_token": token_response.refresh_token,
"created_at": now.isoformat(),
"expires_at": expires_at.isoformat(),
"expires_in": token_response.expires_in,
"scope": token_response.scope,
},
admin_public=True,
source=DocumentSource.CONFLUENCE,
name="Confluence Cloud OAuth",
)
credential = create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth completed successfully.",
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
"redirect_on_success": session.redirect_on_success,
}
)
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
credential_id: int,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Atlassian's API is weird and does not supply us with enough info to be in a
usable state after authorizing. All APIs require a cloud id. We have to list
the accessible resources/sites and let the user choose which site to use."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
access_token = credential_dict["confluence_access_token"]
try:
# Exchange the authorization code for an access token
response = requests.get(
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
)
response.raise_for_status()
accessible_resources_data = response.json()
# Validate the list of AccessibleResources
try:
accessible_resources = [
ConfluenceCloudOAuth.AccessibleResources(**resource)
for resource in accessible_resources_data
]
except ValidationError as e:
raise RuntimeError(f"Failed to parse accessible resources: {e}")
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud get accessible resources completed successfully.",
"accessible_resources": [
resource.model_dump() for resource in accessible_resources
],
}
)
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
credential_id: int,
cloud_id: str,
cloud_name: str,
cloud_url: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Saves the info for the selected cloud site to the credential.
This is the final step in the confluence oauth flow: after the traditional
OAuth process, the user has to select a site to associate with the credentials.
After this, the credential is usable."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url
try:
update_credential_json(credential_id, new_credential_json, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth finalized successfully.",
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
}
)
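
Putting the three endpoints above together: the callback stores a credential, accessible-resources lists the Atlassian sites the token can reach, and finalize pins the chosen site's cloud id onto the credential. A hedged client-side sketch of that sequence (the API root, session auth, and credential id handling are assumptions; the paths come from the router decorators above):

import requests

base_url = "https://onyx.example.com/api"  # hypothetical API root
s = requests.Session()  # assumed to carry an admin session cookie

# 1) exchange the code/state from the OAuth redirect for a stored credential
cb = s.post(
    f"{base_url}/oauth/connector/confluence/callback",
    params={"code": "...", "state": "..."},  # values come from the redirect
).json()

# 2) list the Atlassian sites this token can access
credential_id = 123  # hypothetical; parsed from cb["finalize_url"]
sites = s.get(
    f"{base_url}/oauth/connector/confluence/accessible-resources",
    params={"credential_id": credential_id},
).json()

# 3) pin the chosen site onto the credential to make it usable
site = sites["accessible_resources"][0]
s.post(
    f"{base_url}/oauth/connector/confluence/finalize",
    params={
        "credential_id": credential_id,
        "cloud_id": site["id"],
        "cloud_name": site["name"],
        "cloud_url": site["url"],
    },
)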


@@ -0,0 +1,229 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
else:
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
}
)


@@ -0,0 +1,197 @@
import base64
import uuid
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = SlackOAuth.REDIRECT_URI
else:
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
"team_id": team_id,
"authed_user_id": authed_user_id,
}
)


@@ -138,6 +138,7 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,


@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(tenant_id: str) -> BillingInformation:
def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = BillingInformation(**response.json())
return billing_info
response_data = response.json()
# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)
# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
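
Because fetch_billing_information now returns a union, callers have to branch on the concrete type before touching subscription fields. A minimal consumer sketch (the handler names are hypothetical):

info = fetch_billing_information(tenant_id)
if isinstance(info, SubscriptionStatusResponse):
    # control plane reported subscribed=False for this tenant
    show_upgrade_prompt(info)  # hypothetical handler
else:
    render_billing_page(info)  # hypothetical handler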


@@ -78,7 +78,7 @@ class CloudEmbedding:
self._closed = False
async def _embed_openai(
self, texts: list[str], model: str | None
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
@@ -91,7 +91,11 @@
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(input=text_batch, model=model)
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@@ -223,9 +227,10 @@
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name)
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
@@ -326,6 +331,7 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -369,6 +375,7 @@
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
@@ -508,6 +515,7 @@ async def process_embed_request(
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
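
The dimensions=reduced_dimension or openai.NOT_GIVEN expression above is the usual way to make a parameter conditional with the openai client: NOT_GIVEN omits the field from the request body entirely when no reduction is requested. The same pattern in isolation (a sketch assuming the official openai client; key handling elided):

from openai import NOT_GIVEN, AsyncOpenAI

client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment


async def embed(
    texts: list[str], model: str, reduced_dimension: int | None
) -> list[list[float]]:
    response = await client.embeddings.create(
        input=texts,
        model=model,
        # omit "dimensions" entirely unless a reduction was requested
        dimensions=reduced_dimension or NOT_GIVEN,
    )
    return [d.embedding for d in response.data]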


@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User
user: User | None = None
try:
# Attempt to get user by OAuth account
@@ -420,15 +420,20 @@
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = cast(User, await self.user_db.get_by_email(account_email))
user = await self.user_db.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
@@ -439,26 +444,36 @@
user = await self.user_db.create(user_dict)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
else:
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# User exists, update OAuth account if needed
if user is not None: # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled


@@ -423,7 +423,7 @@ def connector_external_group_sync_generator_task(
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(cc_pair)
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(


@@ -23,9 +23,9 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
@@ -61,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.db.session import get_session_with_current_tenant
from onyx.db.swap_index import check_index_swap
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
@@ -406,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# check for search settings swap
with get_session_with_current_tenant() as db_session:
old_search_settings = check_index_swap(db_session=db_session)
old_search_settings = check_and_perform_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
@@ -439,6 +439,15 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
# skip non-live search settings that don't have background reindex enabled
# those should just auto-change to live shortly after creation without
# requiring any indexing till that point
if (
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -456,23 +465,18 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
if not should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if search_settings_instance.status.is_current():
# the indexing trigger is only checked and cleared with the current search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
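
Taken together with the early continue in the settings loop, the gating here reduces to a single predicate: a search settings instance is considered for indexing only if it is the live ("current") one or it explicitly opted into background reindexing. A condensed restatement of that condition (sketch only, using the fields from this diff):

def is_indexing_candidate(search_settings) -> bool:
    # non-live settings without background reindexing enabled are skipped;
    # they flip to live later without any indexing work
    return (
        search_settings.status.is_current()
        or search_settings.background_reindex_enabled
    )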


@@ -346,11 +346,10 @@ def validate_indexing_fences(
return
def _should_index(
def should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -415,9 +414,9 @@
):
return False
if search_settings_primary:
if search_settings_instance.status.is_current():
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
# if a manual indexing trigger is on the cc pair, honor it for live search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq


@@ -11,10 +11,27 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_current_tenant() as db_session:
try:
error_message = ""
# try to write to the db, but handle IntegrityError specifically
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = (
f"Failed to create background error: {str(e)}. Original message: {message}"
)
except Exception:
pass
if not error_message:
return
# if we get here from an IntegrityError, try to write the error message to the db
# we need a new session because the first session is now invalid
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, error_message, None)
except Exception:
pass


@@ -93,10 +93,11 @@ def _get_connector_runner(
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
logger.exception("Unable to instantiate connector.")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed. Sometimes there are cases where the connector will
# it will never succeed
# Sometimes there are cases where the connector will
# intermittently fail to initialize in which case we should pass in
# leave_connector_active=True to allow it to continue.
# For example, if there is nightly maintenance on a Confluence Server instance,


@@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
from onyx.connectors.confluence.onyx_confluence import (
extract_text_from_confluence_html,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import attachment_to_content
from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
class ConfluenceConnector(
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
):
def __init__(
self,
wiki_base: str,
@@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
self.credentials_provider: CredentialsProviderInterface | None = None
self.probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
self.final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
self._confluence_client: OnyxConfluence | None = None
@property
def confluence_client(self) -> OnyxConfluence:
@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self.credentials_provider = credentials_provider
# raises exception if there's a problem
confluence_client = OnyxConfluence(
self.is_cloud, self.wiki_base, credentials_provider
)
return None
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
self._confluence_client = confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
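    # Usage sketch (hedged: assumes the OnyxStaticCredentialsProvider added in
    # this commit; credential values and constructor defaults are hypothetical):
    #
    #     from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
    #
    #     connector = ConfluenceConnector(wiki_base="https://example.atlassian.net/wiki", is_cloud=True)
    #     provider = OnyxStaticCredentialsProvider(
    #         tenant_id=None,
    #         connector_name="confluence",
    #         credential_json={
    #             "confluence_username": "user@example.com",
    #             "confluence_access_token": "api-token",
    #         },
    #     )
    #     connector.set_credentials_provider(provider)  # probes, then initializes the client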
def _construct_page_query(
self,
@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
return comment_string
def _convert_object_to_document(
self, confluence_object: dict[str, Any]
self,
confluence_object: dict[str, Any],
parent_content_id: str | None = None,
) -> Document | None:
"""
Takes in a confluence object, extracts all metadata, and converts it into a document.
If it's a page, it extracts the text and adds the comments to the document text.
If it's an attachment, it just downloads the attachment and converts it into a document.
parent_content_id: if the object is an attachment, specifies the content id that
the attachment is attached to
"""
# The url and the id are the same
object_url = build_confluence_document_id(
@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client, attachment=confluence_object
confluence_client=self.confluence_client,
attachment=confluence_object,
parent_content_id=parent_content_id,
)
if object_text is None:
@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
doc = self._convert_object_to_document(attachment, confluence_page_id)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:

View File

@ -1,19 +1,37 @@
import math
import io
import json
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote
import bs4
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from redis import Redis
from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -22,12 +40,14 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class ConfluenceRateLimitError(Exception):
pass
@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel):
type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
class OnyxConfluence:
"""
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is a custom Confluence class that:
A. overrides the default Confluence class to add a custom CQL method.
B. wraps the underlying Confluence client so that calls are rate limited and credentials are renewed on the fly.
This is necessary because the default Confluence class does not properly support cql expansions.
All methods are automatically wrapped with handle_confluence_rate_limit.
"""
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
CREDENTIAL_PREFIX = "connector:confluence:credential"
CREDENTIAL_TTL = 300 # 5 min
def _wrap_methods(self) -> None:
def __init__(
self,
is_cloud: bool,
url: str,
credentials_provider: CredentialsProviderInterface,
) -> None:
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
self.redis_client = get_redis_client(
tenant_id=credentials_provider.get_tenant_id()
)
else:
self.static_credentials = self._credentials_provider.get_credentials()
self._confluence = Confluence(url)
self.credential_key: str = (
self.CREDENTIAL_PREFIX
+ f":credential_{self._credentials_provider.get_provider_key()}"
)
self._kwargs: Any = None
self.shared_base_kwargs = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"cloud": is_cloud,
}
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
Returns a tuple
1. The up to date credentials
2. True if the credentials were updated
This method is intended to be used within a distributed lock.
Lock, call this, update credentials if the tokens were refreshed, then release
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
wrap it with handle_confluence_rate_limit.
"""
for attr_name in dir(self):
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
setattr(
self,
attr_name,
handle_confluence_rate_limit(getattr(self, attr_name)),
# static credentials are preloaded, so no locking/redis required
if self.static_credentials:
return self.static_credentials, False
if not self.redis_client:
raise RuntimeError("self.redis_client is None")
# dynamic credentials need locking
# check redis first, then fallback to the DB
credential_raw = self.redis_client.get(self.credential_key)
if credential_raw is not None:
credential_bytes = cast(bytes, credential_raw)
credential_str = credential_bytes.decode("utf-8")
credential_json: dict[str, Any] = json.loads(credential_str)
else:
credential_json = self._credentials_provider.get_credentials()
if "confluence_refresh_token" not in credential_json:
# static credentials ... cache them permanently and return
self.static_credentials = credential_json
return credential_json, False
# check if we should refresh tokens. we're deciding to refresh halfway
# to expiration
now = datetime.now(timezone.utc)
created_at = datetime.fromisoformat(credential_json["created_at"])
expires_in: int = credential_json["expires_in"]
renew_at = created_at + timedelta(seconds=expires_in // 2)
if now <= renew_at:
# cached/current credentials are reasonably up to date
return credential_json, False
# we need to refresh
logger.info("Renewing Confluence Cloud credentials...")
new_credentials = confluence_refresh_tokens(
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
credential_json["cloud_id"],
credential_json["confluence_refresh_token"],
)
# store the new credentials to redis and to the db thru the provider
# redis: we use a 5 min TTL because we are given a 10 minute grace period
# when keys are rotated. it's easier to expire the cached credentials
# reasonably frequently rather than trying to handle strong synchronization
# between the db and redis everywhere the credentials might be updated
new_credential_str = json.dumps(new_credentials)
self.redis_client.set(
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
)
self._credentials_provider.set_credentials(new_credentials)
return new_credentials, True
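        # Worked example of the half-life renewal check above (hypothetical values):
        #
        #     from datetime import datetime, timedelta
        #     created_at = datetime.fromisoformat("2025-01-01T00:00:00+00:00")
        #     expires_in = 3600
        #     renew_at = created_at + timedelta(seconds=expires_in // 2)
        #     # a token created at 00:00 with a 1h lifetime is refreshed on the
        #     # first call after 00:30, while the current token is still valid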
@staticmethod
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
oauth2_dict: dict[str, Any] = {}
if "confluence_refresh_token" in credentials:
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
oauth2_dict["token"] = {}
oauth2_dict["token"]["access_token"] = credentials[
"confluence_access_token"
]
return oauth2_dict
def _probe_connection(
self,
**kwargs: Any,
) -> None:
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Probing Confluence with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
credentials
)
url = (
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
)
confluence_client_with_minimal_retries = Confluence(
url=url, oauth2=oauth2_dict, **merged_kwargs
)
else:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**merged_kwargs,
)
else:
confluence_client_with_minimal_retries = Confluence(
url=url,
token=credentials["confluence_access_token"],
**merged_kwargs,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfortunately, all data is returned in UTC regardless of the user's time zone
# even though CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {url}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
logger.info("Confluence probe succeeded.")
def _initialize_connection(
self,
**kwargs: Any,
) -> None:
"""Called externally to init the connection in a thread safe manner."""
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
self._confluence = self._initialize_connection_helper(
credentials, **merged_kwargs
)
self._kwargs = merged_kwargs
def _initialize_connection_helper(
self,
credentials: dict[str, Any],
**kwargs: Any,
) -> Confluence:
"""Called internally to init the connection. Distributed locking
to prevent multiple threads from modifying the credentials
must be handled around this function."""
confluence = None
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Connecting to Confluence Cloud with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info("Connecting to Confluence with Personal Access Token.")
if self._is_cloud:
confluence = Confluence(
url=self._url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**kwargs,
)
else:
confluence = Confluence(
url=self._url,
token=credentials["confluence_access_token"],
**kwargs,
)
return confluence
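        # For reference, the kwargs received here are shared_base_kwargs merged
        # with the caller's overrides; e.g. the connector's final_kwargs produce
        # (sketch for a cloud instance):
        #
        #     {
        #         "api_version": "cloud",
        #         "backoff_and_retry": True,
        #         "cloud": True,
        #         "max_backoff_retries": 10,
        #         "max_backoff_seconds": 60,
        #     }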
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def _make_rate_limited_confluence_method(
self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
try:
if credential_provider:
with credential_provider:
credentials, renewed = self._renew_credentials()
if renewed:
self._confluence = self._initialize_connection_helper(
credentials, **self._kwargs
)
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
else:
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return wrapped_call
# def _wrap_methods(self) -> None:
# """
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
# wrap it with handle_confluence_rate_limit.
# """
# for attr_name in dir(self):
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
# setattr(
# self,
# attr_name,
# handle_confluence_rate_limit(getattr(self, attr_name)),
# )
# def _ensure_token_valid(self) -> None:
# if self._token_is_expired():
# self._refresh_token()
# # Re-init the Confluence client with the originally stored args
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
def __getattr__(self, name: str) -> Any:
"""Dynamically intercept attribute/method access."""
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
# If it's not a method, just return it after ensuring token validity
if not callable(attr):
return attr
# skip methods that start with "_"
if name.startswith("_"):
return attr
# wrap the method with our retry handler
rate_limited_method: Callable[
..., Any
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
return rate_limited_method(*args, **kwargs)
return wrapped_method
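    # With __getattr__ above, public methods of the wrapped atlassian Confluence
    # client are proxied through the rate-limit/credential-renewal handler, e.g.
    # (sketch; provider construction is shown elsewhere in this commit):
    #
    #     client = OnyxConfluence(is_cloud=True, url=wiki_base, credentials_provider=provider)
    #     client._probe_connection(max_backoff_retries=6, max_backoff_seconds=10)
    #     client._initialize_connection(max_backoff_retries=10, max_backoff_seconds=60)
    #     spaces = client.get_all_spaces(limit=1)  # transparently wrapped call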
def _paginate_url(
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
@ -507,63 +752,212 @@ class OnyxConfluence(Confluence):
return response
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we don't
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown Confluence User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfortunately, all data is returned in UTC regardless of the user's time zone
# even though CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
return extracted_text
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
try:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
except Exception as e:
raise ConnectorValidationError(str(e))
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
cloud=is_cloud,
)
return format_document_soup(soup)

View File

@ -1,185 +1,38 @@
import io
import math
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
import requests
from pydantic import BaseModel
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
pass
logger = setup_logger()
_USER_EMAIL_CACHE: dict[str, str | None] = {}
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
def get_user_email_from_username__server(
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we don't
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e: the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown Confluence User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def extract_text_from_confluence_html(
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
]
def attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
@ -284,6 +94,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
return datetime_object
def confluence_refresh_tokens(
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
) -> dict[str, Any]:
# rotate the refresh and access token
# Note that access tokens are only good for an hour in confluence cloud,
# so we're going to have problems if the connector runs for longer than that
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
response = requests.post(
CONFLUENCE_OAUTH_TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "refresh_token",
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
},
)
try:
token_response = TokenResponse.model_validate_json(response.text)
except Exception:
raise RuntimeError("Confluence Cloud token refresh failed.")
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
new_credentials: dict[str, Any] = {}
new_credentials["confluence_access_token"] = token_response.access_token
new_credentials["confluence_refresh_token"] = token_response.refresh_token
new_credentials["created_at"] = now.isoformat()
new_credentials["expires_at"] = expires_at.isoformat()
new_credentials["expires_in"] = token_response.expires_in
new_credentials["scope"] = token_response.scope
new_credentials["cloud_id"] = cloud_id
return new_credentials
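# Shape of the credential dict returned above (sketch; all values hypothetical):
#
#     {
#         "confluence_access_token": "...",
#         "confluence_refresh_token": "...",
#         "created_at": "2025-01-01T00:00:00+00:00",
#         "expires_at": "2025-01-01T01:00:00+00:00",
#         "expires_in": 3600,
#         "scope": "...",
#         "cloud_id": "...",
#     }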
F = TypeVar("F", bound=Callable[..., Any])
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
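# Backoff schedule implied by the constants above when no usable Retry-After
# header is present (STARTING_DELAY=5, BACKOFF=2, MAX_DELAY=60):
#
#     [min(5 * (2**attempt), 60) for attempt in range(5)]
#     # -> [5, 10, 20, 40, 60] seconds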
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)

View File

@ -0,0 +1,135 @@
import uuid
from types import TracebackType
from typing import Any
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import Credential
from onyx.redis.redis_pool import get_redis_client
class OnyxDBCredentialsProvider(
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
"""Implementation to allow the connector to callback and update credentials in the db.
Required in cases where credentials can rotate while the connector is running.
"""
LOCK_TTL = 900 # TTL of the lock
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_id = credential_id
self.redis_client = get_redis_client(tenant_id=tenant_id)
# lock used to prevent overlapping renewal of credentials
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
def __enter__(self) -> "OnyxDBCredentialsProvider":
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
if not acquired:
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Release the lock when exiting the context."""
if self._lock and self._lock.owned():
self._lock.release()
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return str(self._credential_id)
def get_credentials(self) -> dict[str, Any]:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
credential = db_session.execute(
select(Credential).where(Credential.id == self._credential_id)
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
try:
credential = db_session.execute(
select(Credential)
.where(Credential.id == self._credential_id)
.with_for_update()
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
db_session.commit()
except Exception:
db_session.rollback()
raise
def is_dynamic(self) -> bool:
return True
class OnyxStaticCredentialsProvider(
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "OnyxStaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False
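# Usage sketch for the two providers (tenant/credential values hypothetical):
# the dynamic DB-backed provider is entered as a context manager so the Redis
# lock is held while credentials may rotate; the static provider's context
# manager is a no-op.
#
#     provider = OnyxDBCredentialsProvider("tenant-1", "confluence", 42)
#     with provider:  # acquires da_lock:connector:confluence:credential_42
#         creds = provider.get_credentials()
#         # ... refresh tokens if needed, then provider.set_credentials(new_creds)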

View File

@ -12,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
@ -32,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@ -57,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.models import Credential
from shared_configs.contextvars import get_current_tenant_id
class ConnectorMissingException(Exception):
@ -167,10 +170,17 @@ def instantiate_connector(
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
if isinstance(connector, CredentialsConnector):
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), str(source), credential.id
)
connector.set_credentials_provider(provider)
else:
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
return connector

View File

@ -1,7 +1,10 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar
from pydantic import BaseModel
@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
raise NotImplementedError
T = TypeVar("T", bound="CredentialsProviderInterface")
class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def __enter__(self) -> T:
raise NotImplementedError
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def get_tenant_id(self) -> str | None:
raise NotImplementedError
@abc.abstractmethod
def get_provider_key(self) -> str:
"""a unique key that the connector can use to lock around a credential
that might be used simultaneously.
Will typically be the credential id, but can also just be something random
in cases when there is nothing to lock (aka static credentials)
"""
raise NotImplementedError
@abc.abstractmethod
def get_credentials(self) -> dict[str, Any]:
raise NotImplementedError
@abc.abstractmethod
def set_credentials(self, credential_json: dict[str, Any]) -> None:
raise NotImplementedError
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
needs to use the locking features of the credentials provider to operate
correctly.
If static, the client can simply reference the credentials once and use them
through the entire indexing run.
"""
raise NotImplementedError
class CredentialsConnector(BaseConnector):
"""Implement this if the connector needs to be able to read and write credentials
on the fly. Typically used with shared credentials/tokens that might be renewed
at any time."""
@abc.abstractmethod
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
raise NotImplementedError
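# Minimal sketch of a connector implementing this interface (hypothetical
# class; the concrete example in this commit is ConfluenceConnector):
#
#     class MyConnector(CredentialsConnector):
#         def __init__(self) -> None:
#             self.credentials_provider: CredentialsProviderInterface | None = None
#
#         def set_credentials_provider(
#             self, credentials_provider: CredentialsProviderInterface
#         ) -> None:
#             self.credentials_provider = credentials_provider
#
#         def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
#             raise NotImplementedError("Use set_credentials_provider instead.")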
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@ -72,6 +72,7 @@ def make_slack_api_rate_limited(
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call

View File

@ -42,6 +42,10 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
# Threshold for determining when to replace vs append iframe content
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
@ -138,7 +142,8 @@ def get_internal_links(
# Account for malformed backslashes in URLs
href = href.replace("\\", "/")
if should_ignore_pound and "#" in href:
# "#!" indicates the page is using a hashbang URL, which is a client-side routing technique
if should_ignore_pound and "#" in href and "#!" not in href:
href = href.split("#")[0]
if not is_valid_url(href):
@ -288,6 +293,7 @@ class WebConnector(LoadConnector):
and converts them into documents"""
visited_links: set[str] = set()
to_visit: list[str] = self.to_visit_list
content_hashes = set()
if not to_visit:
raise ValueError("No URLs to visit")
@ -302,29 +308,30 @@ class WebConnector(LoadConnector):
playwright, context = start_playwright()
restart_playwright = False
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
initial_url = to_visit.pop()
if initial_url in visited_links:
continue
visited_links.add(current_url)
visited_links.add(initial_url)
try:
protected_url_check(current_url)
protected_url_check(initial_url)
except Exception as e:
last_error = f"Invalid URL {current_url} due to {e}"
last_error = f"Invalid URL {initial_url} due to {e}"
logger.warning(last_error)
continue
logger.info(f"Visiting {current_url}")
index = len(visited_links)
logger.info(f"{index}: Visiting {initial_url}")
try:
check_internet_connection(current_url)
check_internet_connection(initial_url)
if restart_playwright:
playwright, context = start_playwright()
restart_playwright = False
if current_url.split(".")[-1] == "pdf":
if initial_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(current_url)
response = requests.get(initial_url)
page_text, metadata = read_pdf_file(
file=io.BytesIO(response.content)
)
@ -332,10 +339,10 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
id=initial_url,
sections=[Section(link=initial_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split("/")[-1],
semantic_identifier=initial_url.split("/")[-1],
metadata=metadata,
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@ -347,21 +354,29 @@ class WebConnector(LoadConnector):
continue
page = context.new_page()
page_response = page.goto(current_url)
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
page_response = page.goto(
initial_url,
timeout=30000, # 30 seconds
)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
else None
)
final_page = page.url
if final_page != current_url:
logger.info(f"Redirected to {final_page}")
protected_url_check(final_page)
current_url = final_page
if current_url in visited_links:
logger.info("Redirected page already indexed")
final_url = page.url
if final_url != initial_url:
protected_url_check(final_url)
initial_url = final_url
if initial_url in visited_links:
logger.info(
f"{index}: {initial_url} redirected to {final_url} - already indexed"
)
continue
visited_links.add(current_url)
logger.info(f"{index}: {initial_url} redirected to {final_url}")
visited_links.add(initial_url)
if self.scroll_before_scraping:
scroll_attempts = 0
@ -379,26 +394,58 @@ class WebConnector(LoadConnector):
soup = BeautifulSoup(content, "html.parser")
if self.recursive:
internal_links = get_internal_links(base_url, current_url, soup)
internal_links = get_internal_links(base_url, initial_url, soup)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
if page_response and str(page_response.status)[0] in ("4", "5"):
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
logger.info(last_error)
continue
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
"""For websites containing iframes that need to be scraped,
the code below can extract text from within these iframes.
"""
logger.debug(
f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}"
)
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
iframe_count = page.frame_locator("iframe").locator("html").count()
if iframe_count > 0:
iframe_texts = (
page.frame_locator("iframe")
.locator("html")
.all_inner_texts()
)
document_text = "\n".join(iframe_texts)
""" 700 is the threshold value for the length of the text extracted
from the iframe based on the issue faced """
if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD:
parsed_html.cleaned_text = document_text
else:
parsed_html.cleaned_text += "\n" + document_text
# Sometimes pages with #! will serve duplicate content
# There are also just other ways this can happen
hashed_text = hash((parsed_html.title, parsed_html.cleaned_text))
if hashed_text in content_hashes:
logger.info(
f"{index}: Skipping duplicate title + content for {initial_url}"
)
continue
content_hashes.add(hashed_text)
doc_batch.append(
Document(
id=current_url,
id=initial_url,
sections=[
Section(link=current_url, text=parsed_html.cleaned_text)
Section(link=initial_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or current_url,
semantic_identifier=parsed_html.title or initial_url,
metadata={},
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@ -410,7 +457,7 @@ class WebConnector(LoadConnector):
page.close()
except Exception as e:
last_error = f"Failed to fetch '{current_url}': {e}"
last_error = f"Failed to fetch '{initial_url}': {e}"
logger.exception(last_error)
playwright.stop()
restart_playwright = True
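            # Design note on the dedup above (sketch): hashing (title, cleaned_text)
            # rather than the URL catches duplicates even when redirects or hashbang
            # ("#!") routes serve the same page under different URLs; Python's
            # built-in hash of the tuple is sufficient here since a collision only
            # risks skipping a page within a single crawl.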

View File

@ -76,6 +76,10 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,

View File

@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
stmt = stmt.order_by(desc(ChatSession.time_updated))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@ -962,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
is_agentic=chat_message.is_agentic,
error=chat_message.error,
)

View File

@ -360,18 +360,13 @@ def backend_update_credential_json(
db_session.commit()
def delete_credential(
def _delete_credential_internal(
credential: Credential,
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
"""Internal utility function to handle the actual deletion of a credential"""
associated_connectors = (
db_session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.credential_id == credential_id)
@ -416,6 +411,35 @@ def delete_credential(
db_session.commit()
def delete_credential_for_user(
credential_id: int,
user: User,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential that belongs to a specific user"""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
_delete_credential_internal(credential, credential_id, db_session, force)
def delete_credential(
credential_id: int,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential regardless of ownership (admin function)"""
credential = fetch_credential_by_id(credential_id, db_session)
if credential is None:
raise ValueError(f"Credential by provided id {credential_id} does not exist")
_delete_credential_internal(credential, credential_id, db_session, force)
def create_initial_public_credential(db_session: Session) -> None:
error_msg = (
"DB is not in a valid initial state."

View File

@ -63,6 +63,9 @@ class IndexModelStatus(str, PyEnum):
PRESENT = "PRESENT"
FUTURE = "FUTURE"
def is_current(self) -> bool:
return self == IndexModelStatus.PRESENT
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
@ -83,3 +86,11 @@ class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"
class EmbeddingPrecision(str, PyEnum):
# matches vespa tensor type
# only support float / bfloat16 for now, since there's not a
# good reason to specify anything else
BFLOAT16 = "bfloat16"
FLOAT = "float"

View File

@ -7,6 +7,7 @@ from typing import Optional
from uuid import uuid4
from pydantic import BaseModel
from sqlalchemy.orm import validates
from typing_extensions import TypedDict # noreorder
from uuid import UUID
@ -45,7 +46,13 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
from onyx.db.enums import (
AccessType,
EmbeddingPrecision,
IndexingMode,
SyncType,
SyncStatus,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@ -206,6 +213,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
@property
def password_configured(self) -> bool:
"""
@ -711,6 +722,23 @@ class SearchSettings(Base):
ForeignKey("embedding_provider.provider_type"), nullable=True
)
# Whether switching to this model should re-index all connectors in the background.
# If no re-index is needed, this is ignored. Only used during the switch-over process.
background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# allows for quantization -> less memory usage for a small performance hit
embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
Enum(EmbeddingPrecision, native_enum=False)
)
# can be used to reduce dimensionality of vectors and save memory with
# a small performance hit. More details in the `Reducing embedding dimensions`
# section here:
# https://platform.openai.com/docs/guides/embeddings#embedding-models
# If not specified, the model_dim is used without any reduction.
# NOTE: this is only currently available for OpenAI models
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
@ -792,6 +820,12 @@ class SearchSettings(Base):
self.multipass_indexing, self.model_name, self.provider_type
)
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
@staticmethod
def can_use_large_chunks(
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
@ -1756,6 +1790,7 @@ class ChannelConfig(TypedDict):
channel_name: str | None # None for default channel config
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
is_ephemeral: NotRequired[bool] # defaults to False
respond_member_group_list: NotRequired[list[str]]
answer_filters: NotRequired[list[AllowedAnswerFilters]]
# If None then no follow up
@ -2270,6 +2305,10 @@ class UserTenantMapping(Base):
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):

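A standalone sketch of the dimension-resolution rule introduced by reduced_dimension and final_embedding_dim; the values are illustrative, and the reduction itself relies on the provider (per the comment above, currently only OpenAI models) accepting a smaller output dimension.

def final_embedding_dim(model_dim: int, reduced_dimension: int | None) -> int:
    # reduced_dimension, when set, wins over the model's native dimension
    return reduced_dimension if reduced_dimension else model_dim

assert final_embedding_dim(3072, None) == 3072   # full-size vectors
assert final_embedding_dim(3072, 1024) == 1024   # reduced to save memory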
View File

@ -209,13 +209,21 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
is_default_persona: bool | None = create_persona_request.is_default_persona
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non-public")
if user and user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
if user:
# Curators can edit default personas, but not make them
if (
user.role == UserRole.CURATOR
or user.role == UserRole.GLOBAL_CURATOR
):
is_default_persona = None
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
persona = upsert_persona(
persona_id=persona_id,
@ -241,7 +249,7 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
is_default_persona=is_default_persona,
)
versioned_make_persona_private = fetch_versioned_implementation(
@ -428,7 +436,7 @@ def upsert_persona(
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
is_default_persona: bool | None = None,
label_ids: list[int] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
@ -523,7 +531,11 @@ def upsert_persona(
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
existing_persona.is_default_persona = is_default_persona
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None
else existing_persona.is_default_persona
)
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@ -575,7 +587,9 @@ def upsert_persona(
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
is_default_persona=is_default_persona
if is_default_persona is not None
else False,
labels=labels or [],
)
db_session.add(new_persona)
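The role handling above reduces to a small decision rule; a simplified sketch follows (it ignores the is_public check and the no-user case, and uses plain strings for the roles to keep the sketch dependency-free).

def resolve_is_default_persona(requested: bool, role: str) -> bool | None:
    # role values mirror the UserRole members used above
    if not requested:
        return False  # not asking for a default persona
    if role in ("curator", "global_curator"):
        return None   # curators can edit default personas, but not make them
    if role != "admin":
        raise ValueError("Only admins can make a default persona")
    return True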

View File

@ -13,6 +13,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
@ -59,12 +60,15 @@ def create_search_settings(
index_name=search_settings.index_name,
provider_type=search_settings.provider_type,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
background_reindex_enabled=search_settings.background_reindex_enabled,
)
db_session.add(embedding_model)
@ -305,6 +309,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
@ -322,6 +327,7 @@ def get_new_default_embedding_model() -> IndexingSetting:
return IndexingSetting(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,

View File

@ -8,10 +8,12 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_default_document_index
from onyx.key_value_store.factory import get_kv_store
from onyx.utils.logger import setup_logger
@ -19,7 +21,49 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def check_index_swap(db_session: Session) -> SearchSettings | None:
def _perform_index_swap(
db_session: Session,
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
) -> None:
"""Swap the indices and expire the old one."""
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
document_index.ensure_indices_exist(
primary_embedding_dim=secondary_search_settings.final_embedding_dim,
primary_embedding_precision=secondary_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model, grouped by connector + credential; if the counts match, assume the
new index is done building. If so, swap the indices and expire the old one.
@ -27,52 +71,45 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
# The default CC-pair created for the Ingestion API is unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
search_settings = get_secondary_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
if not search_settings:
if not secondary_search_settings:
return None
# If the secondary search settings are not configured to reindex in the background,
# we can just swap over instantly
if not secondary_search_settings.background_reindex_enabled:
current_search_settings = get_current_search_settings(db_session)
_perform_index_swap(
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
return current_search_settings
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
# function is correct. The unique_cc_indexings are specifically for the existing cc-pairs.
old_search_settings = None
if unique_cc_indexings > cc_pair_count:
logger.error("More unique indexings than cc pairs, should not occur")
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
_perform_index_swap(
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
update_search_settings_status(
search_settings=search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if cc_pair_count > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
old_search_settings = current_search_settings
return old_search_settings
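Condensed, the swap decision above can be read as the following sketch over the same inputs; `secondary` stands for the secondary SearchSettings row, or None.

def should_swap(secondary, cc_pair_count: int, successful_indexings: int) -> bool:
    if secondary is None:
        return False  # no new settings staged, nothing to swap
    if not secondary.background_reindex_enabled:
        return True   # instant swap requested, skip the readiness check
    # otherwise swap only once every existing cc-pair has indexed successfully
    # with the new settings (or there are no cc-pairs at all)
    return cc_pair_count == 0 or cc_pair_count == successful_indexings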

View File

@ -6,6 +6,7 @@ from typing import Any
from onyx.access.models import DocumentAccess
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding
@ -145,17 +146,21 @@ class Verifiable(abc.ABC):
@abc.abstractmethod
def ensure_indices_exist(
self,
index_embedding_dim: int,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
"""
Verify that the document index exists and is consistent with the expectations in the code.
Parameters:
- index_embedding_dim: Vector dimensionality for the vector similarity part of the search
- primary_embedding_dim: Vector dimensionality for the vector similarity part of the search
- primary_embedding_precision: Precision of the vector similarity part of the search
- secondary_index_embedding_dim: Vector dimensionality of the secondary index being built
behind the scenes. The secondary index should only be built when switching
embedding models therefore this dim should be different from the primary index.
- secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index
"""
raise NotImplementedError
@ -164,6 +169,7 @@ class Verifiable(abc.ABC):
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
"""
Register multitenant indices with the document index.

View File

@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME {
summary: dynamic
}
# Title embedding (x1)
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME {
}
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular

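For concreteness: once the template placeholders are substituted (see the _replace_template_values_in_schema helper later in this diff), a bfloat16 index with 384-dimensional vectors would render the field above as field embeddings type tensor<bfloat16>(t{},x[384]); the dimension and precision values here are illustrative.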
View File

@ -5,4 +5,7 @@
<allow
until="DATE_REPLACEMENT"
comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
<allow
until="DATE_REPLACEMENT"
comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
</validation-overrides>

View File

@ -310,6 +310,11 @@ def query_vespa(
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
f"Exception: {str(e)}"
+ (
f"\nResponse: {e.response.text}"
if isinstance(e, httpx.HTTPStatusError)
else ""
)
)
raise httpx.HTTPError(error_base) from e
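The enrichment pattern above in isolation: only httpx.HTTPStatusError carries a response object, so the type check guards the attribute access. A minimal sketch:

import httpx

def describe_vespa_error(e: httpx.HTTPError, base_message: str) -> str:
    # HTTPStatusError (4xx/5xx) exposes the server response; transport
    # errors (timeouts, connection failures) do not
    if isinstance(e, httpx.HTTPStatusError):
        return base_message + f"\nResponse: {e.response.text}"
    return base_message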

View File

@ -26,6 +26,7 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
@ -63,6 +64,7 @@ from onyx.document_index.vespa_constants import DATE_REPLACEMENT
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
@ -112,6 +114,21 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
return "\n".join(doc_lines)
def _replace_template_values_in_schema(
schema_template: str,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> str:
return (
schema_template.replace(
EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value
)
.replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name)
.replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
)
def add_ngrams_to_schema(schema_content: str) -> str:
# Add the match blocks containing gram and gram-size to title and content fields
schema_content = re.sub(
@ -163,8 +180,10 @@ class VespaIndex(DocumentIndex):
def ensure_indices_exist(
self,
index_embedding_dim: int,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
if MULTI_TENANT:
logger.info(
@ -221,18 +240,29 @@ class VespaIndex(DocumentIndex):
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = _replace_template_values_in_schema(
schema_template,
self.index_name,
primary_embedding_dim,
primary_embedding_precision,
)
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
if secondary_index_embedding_dim is None:
raise ValueError("Secondary index embedding dimension is required")
if secondary_index_embedding_precision is None:
raise ValueError("Secondary index embedding precision is required")
upcoming_schema = _replace_template_values_in_schema(
schema_template,
self.secondary_index_name,
secondary_index_embedding_dim,
secondary_index_embedding_precision,
)
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
@ -251,6 +281,7 @@ class VespaIndex(DocumentIndex):
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
if not MULTI_TENANT:
raise ValueError("Multi-tenant is not enabled")
@ -309,13 +340,14 @@ class VespaIndex(DocumentIndex):
for i, index_name in enumerate(indices):
embedding_dim = embedding_dims[i]
embedding_precision = embedding_precisions[i]
logger.info(
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
)
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
schema = _replace_template_values_in_schema(
schema_template, index_name, embedding_dim, embedding_precision
)
schema = schema.replace(
TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
)
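A hypothetical call to the new helper, with illustrative values (the index name is made up for the example):

schema = _replace_template_values_in_schema(
    schema_template=schema_template,
    index_name="danswer_chunk_intfloat_e5_base_v2",
    embedding_dim=768,
    embedding_precision=EmbeddingPrecision.BFLOAT16,
)
# DANSWER_CHUNK_NAME -> "danswer_chunk_intfloat_e5_base_v2",
# VARIABLE_DIM -> "768", EMBEDDING_PRECISION -> "bfloat16"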

View File

@ -6,6 +6,7 @@ from onyx.configs.app_configs import VESPA_TENANT_PORT
from onyx.configs.constants import SOURCE_TYPE
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"

View File

@ -38,6 +38,7 @@ class IndexingEmbedder(ABC):
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
reduced_dimension: int | None,
callback: IndexingHeartbeatInterface | None,
):
self.model_name = model_name
@ -60,6 +61,7 @@ class IndexingEmbedder(ABC):
api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
reduced_dimension=reduced_dimension,
# The below are globally set; this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@ -87,6 +89,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
callback: IndexingHeartbeatInterface | None = None,
):
super().__init__(
@ -99,6 +102,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url,
api_version,
deployment_name,
reduced_dimension,
callback,
)
@ -219,6 +223,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
reduced_dimension=search_settings.reduced_dimension,
callback=callback,
)

View File

@ -5,6 +5,7 @@ from pydantic import Field
from onyx.access.models import DocumentAccess
from onyx.connectors.models import Document
from onyx.db.enums import EmbeddingPrecision
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
from shared_configs.model_server_models import Embedding
@ -143,10 +144,20 @@ class IndexingSetting(EmbeddingModelDetail):
model_dim: int
index_name: str | None
multipass_indexing: bool
embedding_precision: EmbeddingPrecision
reduced_dimension: int | None = None
background_reindex_enabled: bool = True
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
@classmethod
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
return cls(
@ -158,6 +169,9 @@ class IndexingSetting(EmbeddingModelDetail):
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
background_reindex_enabled=search_settings.background_reindex_enabled,
)

View File

@ -19,6 +19,7 @@ from onyx.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()

View File

@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
@ -323,7 +322,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@ -89,6 +89,7 @@ class EmbeddingModel:
callback: IndexingHeartbeatInterface | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@ -100,6 +101,7 @@ class EmbeddingModel:
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.reduced_dimension = reduced_dimension
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@ -188,6 +190,7 @@ class EmbeddingModel:
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
api_url=self.api_url,
reduced_dimension=self.reduced_dimension,
)
start_time = time.time()
@ -300,6 +303,7 @@ class EmbeddingModel:
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
reduced_dimension=search_settings.reduced_dimension,
)

View File

@ -31,12 +31,18 @@ from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.icons import source_to_github_img_link
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessage
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageChannelConfig
from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageMessageInfo
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import build_continue_in_web_ui_id
from onyx.onyxbot.slack.utils import build_feedback_id
from onyx.onyxbot.slack.utils import build_publish_ephemeral_message_id
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.onyxbot.slack.utils import translate_vespa_highlight_to_slack
from onyx.utils.text_processing import decode_escapes
@ -105,6 +111,77 @@ def _build_qa_feedback_block(
)
def _build_ephemeral_publication_block(
channel_id: str,
chat_message_id: int,
message_info: SlackMessageInfo,
original_question_ts: str,
channel_conf: ChannelConfig,
feedback_reminder_id: str | None = None,
) -> Block:
# check whether the message is in a thread
if (
message_info is not None
and message_info.msg_to_respond is not None
and message_info.thread_to_respond is not None
and (message_info.msg_to_respond == message_info.thread_to_respond)
):
respond_ts = None
else:
respond_ts = original_question_ts
action_values_ephemeral_message_channel_config = (
ActionValuesEphemeralMessageChannelConfig(
channel_name=channel_conf.get("channel_name"),
respond_tag_only=channel_conf.get("respond_tag_only"),
respond_to_bots=channel_conf.get("respond_to_bots"),
is_ephemeral=channel_conf.get("is_ephemeral", False),
respond_member_group_list=channel_conf.get("respond_member_group_list"),
answer_filters=channel_conf.get("answer_filters"),
follow_up_tags=channel_conf.get("follow_up_tags"),
show_continue_in_web_ui=channel_conf.get("show_continue_in_web_ui", False),
)
)
action_values_ephemeral_message_message_info = (
ActionValuesEphemeralMessageMessageInfo(
bypass_filters=message_info.bypass_filters,
channel_to_respond=message_info.channel_to_respond,
msg_to_respond=message_info.msg_to_respond,
email=message_info.email,
sender_id=message_info.sender_id,
thread_messages=[],
is_bot_msg=message_info.is_bot_msg,
is_bot_dm=message_info.is_bot_dm,
thread_to_respond=respond_ts,
)
)
action_values_ephemeral_message = ActionValuesEphemeralMessage(
original_question_ts=original_question_ts,
feedback_reminder_id=feedback_reminder_id,
chat_message_id=chat_message_id,
message_info=action_values_ephemeral_message_message_info,
channel_conf=action_values_ephemeral_message_channel_config,
)
return ActionsBlock(
block_id=build_publish_ephemeral_message_id(original_question_ts),
elements=[
ButtonElement(
action_id=SHOW_EVERYONE_ACTION_ID,
text="📢 Share with Everyone",
value=action_values_ephemeral_message.model_dump_json(),
),
ButtonElement(
action_id=KEEP_TO_YOURSELF_ACTION_ID,
text="🤫 Keep to Yourself",
value=action_values_ephemeral_message.model_dump_json(),
),
],
)
def get_document_feedback_blocks() -> Block:
return SectionBlock(
text=(
@ -486,16 +563,21 @@ def build_slack_response_blocks(
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
offer_ephemeral_publication: bool = False,
expecting_search_result: bool = False,
skip_restated_question: bool = False,
) -> list[Block]:
"""
Top-level function that builds all the blocks for the Slack response
and combines them into the final block list.
"""
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
if not skip_restated_question:
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
else:
restate_question_block = []
if expecting_search_result:
answer_blocks = _build_qa_response_blocks(
@ -520,12 +602,36 @@ def build_slack_response_blocks(
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
if (
channel_conf
and channel_conf.get("follow_up_tags") is not None
and not channel_conf.get("is_ephemeral", False)
):
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
publish_ephemeral_message_block = []
if (
offer_ephemeral_publication
and answer.chat_message_id is not None
and message_info.msg_to_respond is not None
and channel_conf is not None
):
publish_ephemeral_message_block.append(
_build_ephemeral_publication_block(
channel_id=message_info.channel_to_respond,
chat_message_id=answer.chat_message_id,
original_question_ts=message_info.msg_to_respond,
message_info=message_info,
channel_conf=channel_conf,
feedback_reminder_id=feedback_reminder_id,
)
)
ai_feedback_block: list[Block] = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
@ -547,6 +653,7 @@ def build_slack_response_blocks(
all_blocks = (
restate_question_block
+ answer_blocks
+ publish_ephemeral_message_block
+ ai_feedback_block
+ citations_divider
+ citations_blocks

View File

@ -2,6 +2,8 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
SHOW_EVERYONE_ACTION_ID = "show-everyone"
KEEP_TO_YOURSELF_ACTION_ID = "keep-to-yourself"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"

View File

@ -1,3 +1,4 @@
import json
from typing import Any
from typing import cast
@ -5,21 +6,32 @@ from slack_sdk import WebClient
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.views import View
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.webhook import WebhookClient
from onyx.chat.models import ChatOnyxBotResponse
from onyx.chat.models import CitationInfo
from onyx.chat.models import QADocsResponse
from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.session import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.blocks import get_document_feedback_blocks
from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import FeedbackVisibility
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from onyx.onyxbot.slack.handlers.handle_message import (
remove_scheduled_feedback_reminder,
@ -35,15 +47,48 @@ from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import get_feedback_visibility
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import TenantSocketModeClient
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _convert_db_doc_id_to_document_ids(
citation_dict: dict[int, int], top_documents: list[SavedSearchDoc]
) -> list[CitationInfo]:
citation_list_with_document_id = []
for citation_num, db_doc_id in citation_dict.items():
if db_doc_id is not None:
matching_doc = next(
(d for d in top_documents if d.db_doc_id == db_doc_id), None
)
if matching_doc:
citation_list_with_document_id.append(
CitationInfo(
citation_num=citation_num, document_id=matching_doc.document_id
)
)
return citation_list_with_document_id
def _build_citation_list(chat_message_detail: ChatMessageDetail) -> list[CitationInfo]:
citation_dict = chat_message_detail.citations
if citation_dict is None:
return []
else:
top_documents = (
chat_message_detail.context_docs.top_documents
if chat_message_detail.context_docs
else []
)
citation_list = _convert_db_doc_id_to_document_ids(citation_dict, top_documents)
return citation_list
def handle_doc_feedback_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
@ -58,7 +103,7 @@ def handle_doc_feedback_button(
external_id = build_feedback_id(query_event_id, doc_id, doc_rank)
channel_id = req.payload["container"]["channel_id"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
data = View(
type="modal",
@ -84,7 +129,7 @@ def handle_generate_answer_button(
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
message_ts = req.payload["message"]["ts"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
user_id = req.payload["user"]["id"]
expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={})
email = expert_info.email if expert_info else None
@ -106,7 +151,7 @@ def handle_generate_answer_button(
# Tell the user that we're working on it: send an ephemeral message that
# we're generating the answer
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
receiver_ids=[user_id],
@ -142,6 +187,178 @@ def handle_generate_answer_button(
)
def handle_publish_ephemeral_message_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
action_id: str,
) -> None:
"""
This function handles the "Share with Everyone" / "Keep to Yourself" buttons
for ephemeral messages.
"""
channel_id = req.payload["channel"]["id"]
ephemeral_message_ts = req.payload["container"]["message_ts"]
slack_sender_id = req.payload["user"]["id"]
response_url = req.payload["response_url"]
webhook = WebhookClient(url=response_url)
# The additional data that was attached to the buttons.
# Specifically, this contains the message_info and channel_conf information
# plus some additional attributes.
value_dict = json.loads(req.payload["actions"][0]["value"])
original_question_ts = value_dict.get("original_question_ts")
if not original_question_ts:
raise ValueError("Missing original_question_ts in the payload")
if not ephemeral_message_ts:
raise ValueError("Missing ephemeral_message_ts in the payload")
feedback_reminder_id = value_dict.get("feedback_reminder_id")
slack_message_info = SlackMessageInfo(**value_dict["message_info"])
channel_conf = value_dict.get("channel_conf")
user_email = value_dict.get("message_info", {}).get("email")
chat_message_id = value_dict.get("chat_message_id")
# Obtain onyx_user and chat_message information
if not chat_message_id:
raise ValueError("Missing chat_message_id in the payload")
with get_session_with_current_tenant() as db_session:
onyx_user = get_user_by_email(user_email, db_session)
if not onyx_user:
raise ValueError("Cannot determine onyx_user_id from email in payload")
try:
chat_message = get_chat_message(chat_message_id, onyx_user.id, db_session)
except ValueError:
chat_message = get_chat_message(
chat_message_id, None, db_session
) # is this a good idea?
except Exception as e:
logger.error(f"Failed to get chat message: {e}")
raise e
chat_message_detail = translate_db_message_to_chat_message_detail(chat_message)
# Construct the citations in the proper format, then the answer in the
# format needed to build the blocks.
citation_list = _build_citation_list(chat_message_detail)
onyx_bot_answer = ChatOnyxBotResponse(
answer=chat_message_detail.message,
citations=citation_list,
chat_message_id=chat_message_id,
docs=QADocsResponse(
top_documents=chat_message_detail.context_docs.top_documents
if chat_message_detail.context_docs
else [],
predicted_flow=None,
predicted_search=None,
applied_source_filters=None,
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
llm_selected_doc_indices=None,
error_msg=None,
)
# Note: we need to use the webhook and the response_url to update/delete ephemeral messages
if action_id == SHOW_EVERYONE_ACTION_ID:
# Convert to non-ephemeral message in thread
try:
webhook.send(
response_type="ephemeral",
text="",
blocks=[],
replace_original=True,
delete_original=True,
)
except Exception as e:
logger.error(f"Failed to send webhook: {e}")
# Remove handling of the ephemeral block and add AI feedback.
all_blocks = build_slack_response_blocks(
answer=onyx_bot_answer,
message_info=slack_message_info,
channel_conf=channel_conf,
use_citations=True,
feedback_reminder_id=feedback_reminder_id,
skip_ai_feedback=False,
offer_ephemeral_publication=False,
skip_restated_question=True,
)
try:
# Post in thread as non-ephemeral message
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
receiver_ids=None, # If respond_member_group_list is set, send to them. TODO: check!
text="Hello! Onyx has some results for you!",
blocks=all_blocks,
thread_ts=original_question_ts,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
send_as_ephemeral=False,
)
except Exception as e:
logger.error(f"Failed to publish ephemeral message: {e}")
raise e
elif action_id == KEEP_TO_YOURSELF_ACTION_ID:
# Keep as an ephemeral message in the channel or thread, but remove the publish button and add a feedback button
changed_blocks = build_slack_response_blocks(
answer=onyx_bot_answer,
message_info=slack_message_info,
channel_conf=channel_conf,
use_citations=True,
feedback_reminder_id=feedback_reminder_id,
skip_ai_feedback=False,
offer_ephemeral_publication=False,
skip_restated_question=True,
)
try:
if slack_message_info.thread_to_respond is not None:
# There seems to be a bug in Slack where an update within the thread
# actually leads to the update being posted in the channel. Therefore,
# for now we delete the original ephemeral message and post a new one
# if the ephemeral message is in a thread.
webhook.send(
response_type="ephemeral",
text="",
blocks=[],
replace_original=True,
delete_original=True,
)
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
receiver_ids=[slack_sender_id],
text="Your personal response, sent as an ephemeral message.",
blocks=changed_blocks,
thread_ts=original_question_ts,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
send_as_ephemeral=True,
)
else:
# This works fine if the ephemeral message is in the channel
webhook.send(
response_type="ephemeral",
text="Your personal response, sent as an ephemeral message.",
blocks=changed_blocks,
replace_original=True,
delete_original=False,
)
except Exception as e:
logger.error(f"Failed to send webhook: {e}")
def handle_slack_feedback(
feedback_id: str,
feedback_type: str,
@ -153,13 +370,20 @@ def handle_slack_feedback(
) -> None:
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
# Get Onyx user from Slack ID
expert_info = expert_info_from_slack_id(
user_id_to_post_confirmation, client, user_cache={}
)
email = expert_info.email if expert_info else None
with get_session_with_current_tenant() as db_session:
onyx_user = get_user_by_email(email, db_session) if email else None
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
feedback_text="",
chat_message_id=message_id,
user_id=None, # no "user" for Slack bot for now
user_id=onyx_user.id if onyx_user else None,
db_session=db_session,
)
remove_scheduled_feedback_reminder(
@ -213,7 +437,7 @@ def handle_slack_feedback(
else:
msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer"
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel_id_to_post_confirmation,
text=msg,
@ -232,7 +456,7 @@ def handle_followup_button(
action_id = cast(str, action.get("block_id"))
channel_id = req.payload["container"]["channel_id"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
update_emote_react(
emoji=DANSWER_FOLLOWUP_EMOJI,
@ -265,7 +489,7 @@ def handle_followup_button(
blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids)
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
text="Received your request for more help",
@ -315,7 +539,7 @@ def handle_followup_resolved_button(
) -> None:
channel_id = req.payload["container"]["channel_id"]
message_ts = req.payload["container"]["message_ts"]
thread_ts = req.payload["container"]["thread_ts"]
thread_ts = req.payload["container"].get("thread_ts", None)
clicker_name = get_clicker_name(req, client)
@ -349,7 +573,7 @@ def handle_followup_resolved_button(
resolved_block = SectionBlock(text=msg_text)
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=channel_id,
text="Your request for help has been addressed!",

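A minimal sketch of the response_url mechanism relied on above: ephemeral messages cannot be edited through the regular Web API, so the interaction payload's response_url (wrapped in slack_sdk's WebhookClient) is used to replace or delete them.

from slack_sdk.webhook import WebhookClient

def delete_ephemeral_message(response_url: str) -> None:
    # Sending replace_original + delete_original through the response_url
    # removes the original ephemeral message from the requester's view
    WebhookClient(url=response_url).send(
        response_type="ephemeral",
        text="",
        blocks=[],
        replace_original=True,
        delete_original=True,
    )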
View File

@ -18,7 +18,7 @@ from onyx.onyxbot.slack.handlers.handle_standard_answers import (
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails
from onyx.onyxbot.slack.utils import fetch_user_ids_from_groups
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import slack_usage_report
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import setup_logger
@ -29,7 +29,7 @@ logger_base = setup_logger()
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender_id:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=details.channel_to_respond,
thread_ts=details.msg_to_respond,
@ -202,7 +202,7 @@ def handle_message(
# which would just respond to the sender
if send_to and is_bot_msg:
if sender_id:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=[sender_id],
@ -220,6 +220,7 @@ def handle_message(
add_slack_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
# standard answers should be published in a thread
used_standard_answer = handle_standard_answers(
message_info=message_info,
receiver_ids=send_to,

View File

@ -33,7 +33,7 @@ from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.handlers.utils import slackify_message_thread
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import CreateChatMessageRequest
@ -82,12 +82,38 @@ def handle_regular_answer(
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
# Capture whether the response mode for the channel is ephemeral. Even if the channel is set
# to respond with an ephemeral message, we still send as non-ephemeral if
# the message is a DM with the Onyx bot.
send_as_ephemeral = (
slack_channel_config.channel_config.get("is_ephemeral", False)
and not message_info.is_bot_dm
)
# If the channel is configured to respond with an ephemeral message,
# or the message is a DM to the Onyx bot, we should use the proper onyx user from the email.
# This makes documents that are privately accessible to the user available to Onyx Bot answers.
# Otherwise - if not ephemeral or a DM to the Onyx bot - we must use None as the user to restrict
# to public docs.
user = None
if message_info.is_bot_dm:
if message_info.is_bot_dm or send_as_ephemeral:
if message_info.email:
with get_session_with_current_tenant() as db_session:
user = get_user_by_email(message_info.email, db_session)
target_thread_ts = (
None
if send_as_ephemeral and len(message_info.thread_messages) < 2
else message_ts_to_respond_to
)
target_receiver_ids = (
[message_info.sender_id]
if message_info.sender_id and send_as_ephemeral
else receiver_ids
)
document_set_names: list[str] | None = None
prompt = None
# If no persona is specified, use the default search based persona
@ -134,11 +160,10 @@ def handle_regular_answer(
history_messages = messages[:-1]
single_message_history = slackify_message_thread(history_messages) or None
# Always check for ACL permissions, also for document sets that were explicitly added
# to the Bot by the Administrator. (Change relative to earlier behavior where all documents
# in an attached document set were available to all users in the channel.)
bypass_acl = False
if slack_channel_config.persona and slack_channel_config.persona.document_sets:
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
if not message_ts_to_respond_to and not is_bot_msg:
# if the message is not a "/onyx" command, then it should have a message ts to respond to
@ -219,12 +244,13 @@ def handle_regular_answer(
# Optionally, respond in thread with the error message, Used primarily
# for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,
receiver_ids=target_receiver_ids,
text=f"Encountered exception when trying to answer: \n\n```{e}```",
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
# In case of failures, don't keep the reaction there permanently
@ -242,32 +268,36 @@ def handle_regular_answer(
if answer is None:
assert DISABLE_GENERATIVE_AI is True
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=receiver_ids,
receiver_ids=target_receiver_ids,
text="Hello! Onyx has some results for you!",
blocks=[
SectionBlock(
text="Onyx is down for maintenance.\nWe're working hard on recharging the AI!"
)
],
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
)
# For a DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also gives the user a notification, which an ephemeral message does not.
if receiver_ids:
respond_in_thread(
# If the channel is ephemeral, we don't need to send a message to the user since they will already see the message
if target_receiver_ids and not send_as_ephemeral:
respond_in_thread_or_channel(
client=client,
channel=channel,
text=(
"👋 Hi, we've just gathered and forwarded the relevant "
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
return False
@ -316,12 +346,13 @@ def handle_regular_answer(
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,
receiver_ids=target_receiver_ids,
text="Found no documents when trying to answer. Did you index any documents?",
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
return True
@ -349,15 +380,27 @@ def handle_regular_answer(
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,
receiver_ids=target_receiver_ids,
text="Found no citations or quotes when trying to answer.",
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
send_as_ephemeral=send_as_ephemeral,
)
return True
if (
send_as_ephemeral
and target_receiver_ids is not None
and len(target_receiver_ids) == 1
):
offer_ephemeral_publication = True
skip_ai_feedback = True
else:
offer_ephemeral_publication = False
skip_ai_feedback = False if feedback_reminder_id else True
all_blocks = build_slack_response_blocks(
message_info=message_info,
answer=answer,
@ -365,31 +408,39 @@ def handle_regular_answer(
use_citations=True, # No longer supporting quotes
feedback_reminder_id=feedback_reminder_id,
expecting_search_result=expecting_search_result,
offer_ephemeral_publication=offer_ephemeral_publication,
skip_ai_feedback=skip_ai_feedback,
)
try:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=[message_info.sender_id]
if message_info.is_bot_msg and message_info.sender_id
else receiver_ids,
receiver_ids=target_receiver_ids,
text="Hello! Onyx has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
# don't unfurl, since otherwise we will have 5+ previews which makes the message very long
unfurl=False,
send_as_ephemeral=send_as_ephemeral,
)
# For a DM (ephemeral message), we need to create a thread via a normal message so the user can see
# the ephemeral message. This also gives the user a notification, which an ephemeral message does not.
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /onyx message
# so we shouldn't send_team_member_message
if receiver_ids and message_ts_to_respond_to is not None:
if (
target_receiver_ids
and message_ts_to_respond_to is not None
and not send_as_ephemeral
and target_thread_ts is not None
):
send_team_member_message(
client=client,
channel=channel,
thread_ts=message_ts_to_respond_to,
thread_ts=target_thread_ts,
receiver_ids=target_receiver_ids,
send_as_ephemeral=send_as_ephemeral,
)
return False

View File

@ -2,7 +2,7 @@ from slack_sdk import WebClient
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import MessageType
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
@ -32,8 +32,10 @@ def send_team_member_message(
client: WebClient,
channel: str,
thread_ts: str,
receiver_ids: list[str] | None = None,
send_as_ephemeral: bool = False,
) -> None:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
text=(
@ -41,4 +43,6 @@ def send_team_member_message(
+ "information to the team. They'll get back to you shortly!"
),
thread_ts=thread_ts,
receiver_ids=None,
send_as_ephemeral=send_as_ephemeral,
)

View File

@ -57,7 +57,9 @@ from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID
from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID
from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID
from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from onyx.onyxbot.slack.handlers.handle_buttons import handle_doc_feedback_button
from onyx.onyxbot.slack.handlers.handle_buttons import handle_followup_button
@ -67,6 +69,9 @@ from onyx.onyxbot.slack.handlers.handle_buttons import (
from onyx.onyxbot.slack.handlers.handle_buttons import (
handle_generate_answer_button,
)
from onyx.onyxbot.slack.handlers.handle_buttons import (
handle_publish_ephemeral_message_button,
)
from onyx.onyxbot.slack.handlers.handle_buttons import handle_slack_feedback
from onyx.onyxbot.slack.handlers.handle_message import handle_message
from onyx.onyxbot.slack.handlers.handle_message import (
@ -81,7 +86,7 @@ from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
from onyx.onyxbot.slack.utils import rephrase_slack_message
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import TenantSocketModeClient
from onyx.redis.redis_pool import get_redis_client
from onyx.server.manage.models import SlackBotTokens
@ -667,7 +672,11 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
feedback_msg_reminder = cast(str, action.get("value"))
feedback_id = cast(str, action.get("block_id"))
channel_id = cast(str, req.payload["container"]["channel_id"])
thread_ts = cast(str, req.payload["container"]["thread_ts"])
thread_ts = cast(
str,
req.payload["container"].get("thread_ts")
or req.payload["container"].get("message_ts"),
)
else:
logger.error("Unable to process feedback. Action not found")
return
@ -783,7 +792,7 @@ def apologize_for_fail(
details: SlackMessageInfo,
client: TenantSocketModeClient,
) -> None:
respond_in_thread(
respond_in_thread_or_channel(
client=client.web_client,
channel=details.channel_to_respond,
thread_ts=details.msg_to_respond,
@ -859,6 +868,14 @@ def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> No
if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]:
# AI Answer feedback
return process_feedback(req, client)
elif action["action_id"] in [
SHOW_EVERYONE_ACTION_ID,
KEEP_TO_YOURSELF_ACTION_ID,
]:
# Publish ephemeral message or keep hidden in main channel
return handle_publish_ephemeral_message_button(
req, client, action["action_id"]
)
elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID:
# Activation of the "source feedback" button
return handle_doc_feedback_button(req, client)

View File

@ -1,3 +1,5 @@
from typing import Literal
from pydantic import BaseModel
from onyx.chat.models import ThreadMessage
@ -13,3 +15,37 @@ class SlackMessageInfo(BaseModel):
bypass_filters: bool # User has tagged @OnyxBot
is_bot_msg: bool # User is using /OnyxBot
is_bot_dm: bool # User is direct messaging to OnyxBot
# Models used to encode the relevant data for the ephemeral message actions
class ActionValuesEphemeralMessageMessageInfo(BaseModel):
bypass_filters: bool | None
channel_to_respond: str | None
msg_to_respond: str | None
email: str | None
sender_id: str | None
thread_messages: list[ThreadMessage] | None
is_bot_msg: bool | None
is_bot_dm: bool | None
thread_to_respond: str | None
class ActionValuesEphemeralMessageChannelConfig(BaseModel):
channel_name: str | None
respond_tag_only: bool | None
respond_to_bots: bool | None
is_ephemeral: bool
respond_member_group_list: list[str] | None
answer_filters: list[
Literal["well_answered_postfilter", "questionmark_prefilter"]
] | None
follow_up_tags: list[str] | None
show_continue_in_web_ui: bool
class ActionValuesEphemeralMessage(BaseModel):
original_question_ts: str | None
feedback_reminder_id: str | None
chat_message_id: int
message_info: ActionValuesEphemeralMessageMessageInfo
channel_conf: ActionValuesEphemeralMessageChannelConfig
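These models round-trip through the button payloads: the blocks above serialize them with model_dump_json(), and handle_publish_ephemeral_message_button parses them back with json.loads. A sketch with illustrative values:

import json

value = ActionValuesEphemeralMessage(
    original_question_ts="1700000000.000100",  # illustrative Slack ts
    feedback_reminder_id=None,
    chat_message_id=42,
    message_info=ActionValuesEphemeralMessageMessageInfo(
        bypass_filters=False, channel_to_respond="C0123456789",
        msg_to_respond=None, email="user@example.com", sender_id="U0123456789",
        thread_messages=[], is_bot_msg=False, is_bot_dm=False,
        thread_to_respond=None,
    ),
    channel_conf=ActionValuesEphemeralMessageChannelConfig(
        channel_name="support", respond_tag_only=None, respond_to_bots=None,
        is_ephemeral=True, respond_member_group_list=None, answer_filters=None,
        follow_up_tags=None, show_continue_in_web_ui=False,
    ),
).model_dump_json()

value_dict = json.loads(value)  # what the button handler receives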

View File

@ -184,7 +184,7 @@ def _build_error_block(error_message: str) -> Block:
backoff=2,
logger=cast(logging.Logger, logger),
)
def respond_in_thread(
def respond_in_thread_or_channel(
client: WebClient,
channel: str,
thread_ts: str | None,
@ -193,6 +193,7 @@ def respond_in_thread(
receiver_ids: list[str] | None = None,
metadata: Metadata | None = None,
unfurl: bool = True,
send_as_ephemeral: bool | None = True,
) -> list[str]:
if not text and not blocks:
raise ValueError("One of `text` or `blocks` must be provided")
@ -236,6 +237,7 @@ def respond_in_thread(
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
@ -299,6 +301,12 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_publish_ephemeral_message_id(
original_question_ts: str,
) -> str:
return "publish_ephemeral_message__" + original_question_ts
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
@ -539,7 +547,7 @@ def read_slack_thread(
# If auto-detected filters are on, use the second block for the actual answer
# The first block is the auto-detected filters
if message.startswith("_Filters"):
if message is not None and message.startswith("_Filters"):
if len(blocks) < 2:
logger.warning(f"Only filter blocks found: {reply}")
continue
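
The guard above handles the layout where, with auto-detected filters enabled, a reply's first block holds the filter summary and the answer sits in the second block. A rough sketch of that branch (the block payload shape here is assumed for illustration):

# Assumed reply shape: first block is the "_Filters" summary, second is the answer.
blocks = [
    {"text": {"text": "_Filters: source=confluence_"}},
    {"text": {"text": "The actual answer text."}},
]
message = blocks[0]["text"]["text"]
if message is not None and message.startswith("_Filters"):
    if len(blocks) < 2:
        raise ValueError("Only filter blocks found")  # the real code logs and skips
    message = blocks[1]["text"]["text"]
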
@ -611,7 +619,7 @@ class SlackRateLimiter:
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: str | None
) -> None:
respond_in_thread(
respond_in_thread_or_channel(
client=client,
channel=channel,
receiver_ids=None,


@ -13,6 +13,7 @@ from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
from onyx.db.credentials import delete_credential
from onyx.db.credentials import delete_credential_for_user
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import fetch_credentials_by_source_for_user
from onyx.db.credentials import fetch_credentials_for_user
@ -88,7 +89,7 @@ def delete_credential_by_id_admin(
db_session: Session = Depends(get_session),
) -> StatusResponse:
"""Same as the user endpoint, but can delete any credential (not just the user's own)"""
delete_credential(db_session=db_session, credential_id=credential_id, user=None)
delete_credential(db_session=db_session, credential_id=credential_id)
return StatusResponse(
success=True, message="Credential deleted successfully", data=credential_id
)
@ -242,7 +243,7 @@ def delete_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
delete_credential(
delete_credential_for_user(
credential_id,
user,
db_session,
@ -259,7 +260,7 @@ def force_delete_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
delete_credential(credential_id, user, db_session, True)
delete_credential_for_user(credential_id, user, db_session, True)
return StatusResponse(
success=True, message="Credential deleted successfully", data=credential_id


@ -49,6 +49,7 @@ def get_folders(
name=chat_session.description,
persona_id=chat_session.persona_id,
time_created=chat_session.time_created.isoformat(),
time_updated=chat_session.time_updated.isoformat(),
shared_status=chat_session.shared_status,
folder_id=folder.id,
)


@ -181,6 +181,7 @@ class SlackChannelConfigCreationRequest(BaseModel):
channel_name: str
respond_tag_only: bool = False
respond_to_bots: bool = False
is_ephemeral: bool = False
show_continue_in_web_ui: bool = False
enable_auto_filters: bool = False
# If no team members, assume respond in the channel to everyone


@ -72,11 +72,13 @@ def set_new_search_settings(
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.dict()
search_values = search_settings_new.model_dump()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
new_search_settings_request = SavedSearchSettings(
**search_settings_new.model_dump()
)
secondary_search_settings = get_secondary_search_settings(db_session)
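
The .dict() to .model_dump() swap above is the Pydantic v2 rename of the same serialization call; .dict() still exists in v2 but is deprecated. Roughly, with a stand-in model (not the real SavedSearchSettings):

from pydantic import BaseModel

class SearchSettingsSketch(BaseModel):  # stand-in model for illustration
    index_name: str
    model_dim: int

settings = SearchSettingsSketch(index_name="danswer_index", model_dim=768)
values = settings.model_dump()      # Pydantic v2 replacement for .dict()
values["index_name"] += "_alt"      # mirrors the ALT_INDEX_SUFFIX adjustment (suffix illustrative)
rebuilt = SearchSettingsSketch(**values)
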
@ -103,8 +105,10 @@ def set_new_search_settings(
document_index = get_default_document_index(search_settings, new_search_settings)
document_index.ensure_indices_exist(
index_embedding_dim=search_settings.model_dim,
secondary_index_embedding_dim=new_search_settings.model_dim,
primary_embedding_dim=search_settings.final_embedding_dim,
primary_embedding_precision=search_settings.embedding_precision,
secondary_index_embedding_dim=new_search_settings.final_embedding_dim,
secondary_index_embedding_precision=new_search_settings.embedding_precision,
)
# Pause index attempts for the currently in use index to preserve resources
@ -137,6 +141,17 @@ def cancel_new_embedding(
db_session=db_session,
)
# remove the old index from the vector db
primary_search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(primary_search_settings, None)
document_index.ensure_indices_exist(
primary_embedding_dim=primary_search_settings.final_embedding_dim,
primary_embedding_precision=primary_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
@router.delete("/delete-search-settings")
def delete_search_settings_endpoint(


@ -71,6 +71,15 @@ def _form_channel_config(
"also respond to a predetermined set of users."
)
if (
slack_channel_config_creation_request.is_ephemeral
and slack_channel_config_creation_request.respond_member_group_list
):
raise ValueError(
"Cannot set OnyxBot to respond to users in a private (ephemeral) message "
"and also respond to a selected list of users."
)
channel_config: ChannelConfig = {
"channel_name": cleaned_channel_name,
}
@ -91,6 +100,8 @@ def _form_channel_config(
"respond_to_bots"
] = slack_channel_config_creation_request.respond_to_bots
channel_config["is_ephemeral"] = slack_channel_config_creation_request.is_ephemeral
channel_config["disabled"] = slack_channel_config_creation_request.disabled
return channel_config
@ -343,7 +354,8 @@ def list_bot_configs(
]
MAX_CHANNELS = 200
MAX_SLACK_PAGES = 5
SLACK_API_CHANNELS_PER_PAGE = 100
@router.get(
@ -355,8 +367,8 @@ def get_all_channels_from_slack_api(
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches all channels from the Slack API.
If the workspace has 200 or more channels, we raise an error.
Fetches channels the bot is a member of from the Slack API.
Handles pagination with a limit to avoid excessive API calls.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
if not tokens or "bot_token" not in tokens:
@ -365,28 +377,60 @@ def get_all_channels_from_slack_api(
)
client = WebClient(token=tokens["bot_token"])
all_channels = []
next_cursor = None
current_page = 0
try:
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=MAX_CHANNELS,
)
# Use users_conversations with limited pagination
while current_page < MAX_SLACK_PAGES:
current_page += 1
# Make API call with cursor if we have one
if next_cursor:
response = client.users_conversations(
types="public_channel,private_channel",
exclude_archived=True,
cursor=next_cursor,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
else:
response = client.users_conversations(
types="public_channel,private_channel",
exclude_archived=True,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
# Add channels to our list
if "channels" in response and response["channels"]:
all_channels.extend(response["channels"])
# Check if we need to paginate
if (
"response_metadata" in response
and "next_cursor" in response["response_metadata"]
):
next_cursor = response["response_metadata"]["next_cursor"]
if next_cursor:
if current_page == MAX_SLACK_PAGES:
raise HTTPException(
status_code=400,
detail="Workspace has too many channels to paginate over in this call.",
)
continue
# If we get here, no more pages
break
channels = [
SlackChannel(id=channel["id"], name=channel["name"])
for channel in response["channels"]
for channel in all_channels
]
if len(channels) == MAX_CHANNELS:
raise HTTPException(
status_code=400,
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
)
return channels
except SlackApiError as e:
# Handle rate limiting or other API errors
raise HTTPException(
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",


@ -147,9 +147,11 @@ def list_threads(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
current_temperature_override=chat.temperature_override,
)
for chat in chat_sessions
]


@ -119,6 +119,7 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,


@ -181,6 +181,7 @@ class ChatSessionDetails(BaseModel):
name: str | None
persona_id: int | None = None
time_created: str
time_updated: str
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
@ -241,6 +242,7 @@ class ChatMessageDetail(BaseModel):
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
refined_answer_improvement: bool | None = None
is_agentic: bool | None = None
error: str | None = None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore


@ -159,6 +159,7 @@ def get_user_search_sessions(
name=sessions_with_documents_dict[search.id],
persona_id=search.persona_id,
time_created=search.time_created.isoformat(),
time_updated=search.time_updated.isoformat(),
shared_status=search.shared_status,
folder_id=search.folder_id,
current_alternate_model=search.current_alternate_model,


@ -46,13 +46,21 @@ def mask_string(sensitive_str: str) -> str:
return "****...**" + sensitive_str[-4:]
MASK_CREDENTIALS_WHITELIST = {
DB_CREDENTIALS_AUTHENTICATION_METHOD,
"wiki_base",
"cloud_name",
"cloud_id",
}
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
masked_creds = {}
for key, val in credential_dict.items():
if isinstance(val, str):
# we want to pass the authentication_method field through so the frontend
# can disambiguate credentials created by different methods
if key == DB_CREDENTIALS_AUTHENTICATION_METHOD:
if key in MASK_CREDENTIALS_WHITELIST:
masked_creds[key] = val
else:
masked_creds[key] = mask_string(val)
@ -63,8 +71,8 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
continue
raise ValueError(
f"Unable to mask credentials of type other than string, cannot process request."
f"Recieved type: {type(val)}"
f"Unable to mask credentials of type other than string or int, cannot process request."
f"Received type: {type(val)}"
)
return masked_creds
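
Putting the whitelist and mask_string together, the intended behavior looks roughly like this (credential values invented):

creds = {
    "cloud_name": "acme",                 # whitelisted: passed through for display
    "confluence_access_token": "abcd1234efgh5678",
}
masked = mask_credential_dict(creds)
# -> {"cloud_name": "acme", "confluence_access_token": "****...**5678"}
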


@ -21,6 +21,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.credentials import create_initial_public_credential
from onyx.db.document import check_docs_exist
from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_provider
@ -32,7 +33,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_current_search_settings
from onyx.db.search_settings import update_secondary_search_settings
from onyx.db.swap_index import check_index_swap
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
@ -73,7 +74,7 @@ def setup_onyx(
The Tenant Service calls the tenants/create endpoint which runs this.
"""
check_index_swap(db_session=db_session)
check_and_perform_index_swap(db_session=db_session)
active_search_settings = get_active_search_settings(db_session)
search_settings = active_search_settings.primary
@ -243,10 +244,18 @@ def setup_vespa(
try:
logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...")
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
if secondary_index_setting
else None,
primary_embedding_dim=index_setting.final_embedding_dim,
primary_embedding_precision=index_setting.embedding_precision,
secondary_index_embedding_dim=(
secondary_index_setting.final_embedding_dim
if secondary_index_setting
else None
),
secondary_index_embedding_precision=(
secondary_index_setting.embedding_precision
if secondary_index_setting
else None
),
)
logger.notice("Vespa setup complete.")
@ -360,6 +369,11 @@ def setup_vespa_multitenant(supported_indices: list[SupportedEmbeddingModel]) ->
],
embedding_dims=[index.dim for index in supported_indices]
+ [index.dim for index in supported_indices],
# on the cloud, just use float for all indices, the option to change this
# is not exposed to the user
embedding_precisions=[
EmbeddingPrecision.FLOAT for _ in range(len(supported_indices) * 2)
],
)
logger.notice("Vespa setup complete.")


@ -136,7 +136,7 @@ def seed_dummy_docs(
search_settings = get_current_search_settings(db_session)
multipass_config = get_multipass_config(search_settings)
index_name = search_settings.index_name
embedding_dim = search_settings.model_dim
embedding_dim = search_settings.final_embedding_dim
vespa_index = VespaIndex(
index_name=index_name,


@ -30,6 +30,12 @@ class EmbedRequest(BaseModel):
manual_passage_prefix: str | None = None
api_url: str | None = None
api_version: str | None = None
# Allows truncation of the embedding vector to a lower dimension to
# reduce memory usage. Currently only supported for OpenAI models;
# ignored for other providers.
reduced_dimension: int | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
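
For reference, on the OpenAI side reduced_dimension presumably maps to the dimensions parameter supported by the text-embedding-3-* models. A sketch of the downstream call, assuming the OpenAI Python SDK v1 (model name and sizes illustrative):

from openai import OpenAI

client = OpenAI(api_key="sk-...")  # placeholder key
resp = client.embeddings.create(
    model="text-embedding-3-small",
    input=["some passage to embed"],
    dimensions=256,  # truncated output size, i.e. reduced_dimension
)
assert len(resp.data[0].embedding) == 256
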


@ -5,7 +5,9 @@ from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
@ -18,12 +20,15 @@ def confluence_connector() -> ConfluenceConnector:
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
)
connector.load_credentials(
credentials_provider = OnyxStaticCredentialsProvider(
None,
DocumentSource.CONFLUENCE,
{
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
}
},
)
connector.set_credentials_provider(credentials_provider)
return connector


@ -2,7 +2,9 @@ import os
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
@pytest.fixture
@ -11,12 +13,16 @@ def confluence_connector() -> ConfluenceConnector:
wiki_base="https://danswerai.atlassian.net",
is_cloud=True,
)
connector.load_credentials(
credentials_provider = OnyxStaticCredentialsProvider(
None,
DocumentSource.CONFLUENCE,
{
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
}
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
},
)
connector.set_credentials_provider(credentials_provider)
return connector


@ -16,7 +16,7 @@ from onyx.db.engine import SYNC_DB_API
from onyx.db.search_settings import get_current_search_settings
from onyx.db.session import get_session_context_manager
from onyx.db.session import get_session_with_tenant
from onyx.db.swap_index import check_index_swap
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.db.tenant import get_all_tenant_ids
from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
@ -194,7 +194,7 @@ def reset_vespa() -> None:
with get_session_context_manager() as db_session:
# swap to the correct default model
check_index_swap(db_session)
check_and_perform_index_swap(db_session)
search_settings = get_current_search_settings(db_session)
multipass_config = get_multipass_config(search_settings)
@ -289,7 +289,7 @@ def reset_vespa_multitenant() -> None:
for tenant_id in get_all_tenant_ids():
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# swap to the correct default model for each tenant
check_index_swap(db_session)
check_and_perform_index_swap(db_session)
search_settings = get_current_search_settings(db_session)
multipass_config = get_multipass_config(search_settings)


@ -108,3 +108,13 @@ def admin_user() -> DATestUser:
@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()
def pytest_runtest_logstart(nodeid: str, location: tuple[str, int | None, str]) -> None:
print(f"\nTest start: {nodeid}")
def pytest_runtest_logfinish(
nodeid: str, location: tuple[str, int | None, str]
) -> None:
print(f"\nTest end: {nodeid}")


@ -142,8 +142,12 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
selected_cc_pair = CCPairManager.get_indexing_status_by_id(
cc_pair_1.id, user_performing_action=admin_user
)
assert selected_cc_pair is not None, "cc_pair not found after indexing!"
assert selected_cc_pair.docs_indexed == 15
# used to be 15, but now
# localhost:8889/ and localhost:8889/index.html are deduped
assert selected_cc_pair.docs_indexed == 14
logger.info("Removing about.html.")
os.remove(os.path.join(website_tgt, "about.html"))
@ -160,24 +164,29 @@ def test_web_pruning(reset: None, vespa_client: vespa_fixture) -> None:
cc_pair_1.id, user_performing_action=admin_user
)
assert selected_cc_pair is not None, "cc_pair not found after pruning!"
assert selected_cc_pair.docs_indexed == 13
assert selected_cc_pair.docs_indexed == 12
# check vespa
root_id = f"http://{hostname}:{port}/"
index_id = f"http://{hostname}:{port}/index.html"
about_id = f"http://{hostname}:{port}/about.html"
courses_id = f"http://{hostname}:{port}/courses.html"
doc_ids = [index_id, about_id, courses_id]
doc_ids = [root_id, index_id, about_id, courses_id]
retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"]
retrieved_docs = {
doc["fields"]["document_id"]: doc["fields"]
for doc in retrieved_docs_dict
}
# verify index.html exists in Vespa
retrieved_doc = retrieved_docs.get(index_id)
# verify root exists in Vespa
retrieved_doc = retrieved_docs.get(root_id)
assert retrieved_doc
# verify index.html does not exist in Vespa since it is a duplicate of root
retrieved_doc = retrieved_docs.get(index_id)
assert not retrieved_doc
# verify about and courses do not exist
retrieved_doc = retrieved_docs.get(about_id)
assert not retrieved_doc


@ -64,7 +64,7 @@ async def test_openai_embedding(
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
result = await embedding._embed_openai(
["test1", "test2"], "text-embedding-ada-002"
["test1", "test2"], "text-embedding-ada-002", None
)
assert result == sample_embeddings
@ -89,6 +89,7 @@ async def test_embed_text_cloud_provider() -> None:
prefix=None,
api_url=None,
api_version=None,
reduced_dimension=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
@ -114,6 +115,7 @@ async def test_embed_text_local_model() -> None:
prefix=None,
api_url=None,
api_version=None,
reduced_dimension=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
@ -157,6 +159,7 @@ async def test_rate_limit_handling() -> None:
prefix=None,
api_url=None,
api_version=None,
reduced_dimension=None,
)
@ -179,6 +182,7 @@ async def test_concurrent_embeddings() -> None:
manual_passage_prefix=None,
api_url=None,
api_version=None,
reduced_dimension=None,
)
with patch("model_server.encoders.get_embedding_model") as mock_get_model:


@ -3,9 +3,7 @@ from unittest.mock import Mock
import pytest
from requests import HTTPError
from onyx.connectors.confluence.onyx_confluence import (
handle_confluence_rate_limit,
)
from onyx.connectors.confluence.utils import handle_confluence_rate_limit
@pytest.fixture
@ -50,6 +48,8 @@ def mock_confluence_call() -> Mock:
# mock_sleep.assert_called_with(int(retry_after))
# NOTE(rkuo): This tests an older version of rate limiting that is being deprecated
# and probably should go away soon.
def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
mock_confluence_call.side_effect = HTTPError(
response=Mock(status_code=500, text="Internal Server Error")

web/package-lock.json (generated)

@ -52,7 +52,7 @@
"lodash": "^4.17.21",
"lucide-react": "^0.454.0",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^15.0.2",
"next": "^15.2.0",
"next-themes": "^0.4.4",
"npm": "^10.8.0",
"postcss": "^8.4.31",
@ -2631,9 +2631,10 @@
}
},
"node_modules/@next/env": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.0.2.tgz",
"integrity": "sha512-c0Zr0ModK5OX7D4ZV8Jt/wqoXtitLNPwUfG9zElCZztdaZyNVnN40rDXVZ/+FGuR4CcNV5AEfM6N8f+Ener7Dg=="
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/env/-/env-15.2.0.tgz",
"integrity": "sha512-eMgJu1RBXxxqqnuRJQh5RozhskoNUDHBFybvi+Z+yK9qzKeG7dadhv/Vp1YooSZmCnegf7JxWuapV77necLZNA==",
"license": "MIT"
},
"node_modules/@next/eslint-plugin-next": {
"version": "14.2.3",
@ -2645,12 +2646,13 @@
}
},
"node_modules/@next/swc-darwin-arm64": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.0.2.tgz",
"integrity": "sha512-GK+8w88z+AFlmt+ondytZo2xpwlfAR8U6CRwXancHImh6EdGfHMIrTSCcx5sOSBei00GyLVL0ioo1JLKTfprgg==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-15.2.0.tgz",
"integrity": "sha512-rlp22GZwNJjFCyL7h5wz9vtpBVuCt3ZYjFWpEPBGzG712/uL1bbSkS675rVAUCRZ4hjoTJ26Q7IKhr5DfJrHDA==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
@ -2660,12 +2662,13 @@
}
},
"node_modules/@next/swc-darwin-x64": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.0.2.tgz",
"integrity": "sha512-KUpBVxIbjzFiUZhiLIpJiBoelqzQtVZbdNNsehhUn36e2YzKHphnK8eTUW1s/4aPy5kH/UTid8IuVbaOpedhpw==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-15.2.0.tgz",
"integrity": "sha512-DiU85EqSHogCz80+sgsx90/ecygfCSGl5P3b4XDRVZpgujBm5lp4ts7YaHru7eVTyZMjHInzKr+w0/7+qDrvMA==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
@ -2675,12 +2678,13 @@
}
},
"node_modules/@next/swc-linux-arm64-gnu": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.0.2.tgz",
"integrity": "sha512-9J7TPEcHNAZvwxXRzOtiUvwtTD+fmuY0l7RErf8Yyc7kMpE47MIQakl+3jecmkhOoIyi/Rp+ddq7j4wG6JDskQ==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-15.2.0.tgz",
"integrity": "sha512-VnpoMaGukiNWVxeqKHwi8MN47yKGyki5q+7ql/7p/3ifuU2341i/gDwGK1rivk0pVYbdv5D8z63uu9yMw0QhpQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
@ -2690,12 +2694,13 @@
}
},
"node_modules/@next/swc-linux-arm64-musl": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.0.2.tgz",
"integrity": "sha512-BjH4ZSzJIoTTZRh6rG+a/Ry4SW0HlizcPorqNBixBWc3wtQtj4Sn9FnRZe22QqrPnzoaW0ctvSz4FaH4eGKMww==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-15.2.0.tgz",
"integrity": "sha512-ka97/ssYE5nPH4Qs+8bd8RlYeNeUVBhcnsNUmFM6VWEob4jfN9FTr0NBhXVi1XEJpj3cMfgSRW+LdE3SUZbPrw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
@ -2705,12 +2710,13 @@
}
},
"node_modules/@next/swc-linux-x64-gnu": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.0.2.tgz",
"integrity": "sha512-i3U2TcHgo26sIhcwX/Rshz6avM6nizrZPvrDVDY1bXcLH1ndjbO8zuC7RoHp0NSK7wjJMPYzm7NYL1ksSKFreA==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-15.2.0.tgz",
"integrity": "sha512-zY1JduE4B3q0k2ZCE+DAF/1efjTXUsKP+VXRtrt/rJCTgDlUyyryx7aOgYXNc1d8gobys/Lof9P9ze8IyRDn7Q==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
@ -2720,12 +2726,13 @@
}
},
"node_modules/@next/swc-linux-x64-musl": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.0.2.tgz",
"integrity": "sha512-AMfZfSVOIR8fa+TXlAooByEF4OB00wqnms1sJ1v+iu8ivwvtPvnkwdzzFMpsK5jA2S9oNeeQ04egIWVb4QWmtQ==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-15.2.0.tgz",
"integrity": "sha512-QqvLZpurBD46RhaVaVBepkVQzh8xtlUN00RlG4Iq1sBheNugamUNPuZEH1r9X1YGQo1KqAe1iiShF0acva3jHQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
@ -2735,12 +2742,13 @@
}
},
"node_modules/@next/swc-win32-arm64-msvc": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.0.2.tgz",
"integrity": "sha512-JkXysDT0/hEY47O+Hvs8PbZAeiCQVxKfGtr4GUpNAhlG2E0Mkjibuo8ryGD29Qb5a3IOnKYNoZlh/MyKd2Nbww==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-15.2.0.tgz",
"integrity": "sha512-ODZ0r9WMyylTHAN6pLtvUtQlGXBL9voljv6ujSlcsjOxhtXPI1Ag6AhZK0SE8hEpR1374WZZ5w33ChpJd5fsjw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
@ -2750,12 +2758,13 @@
}
},
"node_modules/@next/swc-win32-x64-msvc": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.0.2.tgz",
"integrity": "sha512-foaUL0NqJY/dX0Pi/UcZm5zsmSk5MtP/gxx3xOPyREkMFN+CTjctPfu3QaqrQHinaKdPnMWPJDKt4VjDfTBe/Q==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-15.2.0.tgz",
"integrity": "sha512-8+4Z3Z7xa13NdUuUAcpVNA6o76lNPniBd9Xbo02bwXQXnZgFvEopwY2at5+z7yHl47X9qbZpvwatZ2BRo3EdZw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
@ -7389,11 +7398,12 @@
"integrity": "sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ=="
},
"node_modules/@swc/helpers": {
"version": "0.5.13",
"resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.13.tgz",
"integrity": "sha512-UoKGxQ3r5kYI9dALKJapMmuK+1zWM/H17Z1+iwnNmzcJRnfFuevZs375TA5rW31pu4BS4NoSy1fRsexDXfWn5w==",
"version": "0.5.15",
"resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.15.tgz",
"integrity": "sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==",
"license": "Apache-2.0",
"dependencies": {
"tslib": "^2.4.0"
"tslib": "^2.8.0"
}
},
"node_modules/@tailwindcss/typography": {
@ -14966,13 +14976,14 @@
"integrity": "sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw=="
},
"node_modules/next": {
"version": "15.0.2",
"resolved": "https://registry.npmjs.org/next/-/next-15.0.2.tgz",
"integrity": "sha512-rxIWHcAu4gGSDmwsELXacqAPUk+j8dV/A9cDF5fsiCMpkBDYkO2AEaL1dfD+nNmDiU6QMCFN8Q30VEKapT9UHQ==",
"version": "15.2.0",
"resolved": "https://registry.npmjs.org/next/-/next-15.2.0.tgz",
"integrity": "sha512-VaiM7sZYX8KIAHBrRGSFytKknkrexNfGb8GlG6e93JqueCspuGte8i4ybn8z4ww1x3f2uzY4YpTaBEW4/hvsoQ==",
"license": "MIT",
"dependencies": {
"@next/env": "15.0.2",
"@next/env": "15.2.0",
"@swc/counter": "0.1.3",
"@swc/helpers": "0.5.13",
"@swc/helpers": "0.5.15",
"busboy": "1.6.0",
"caniuse-lite": "^1.0.30001579",
"postcss": "8.4.31",
@ -14982,25 +14993,25 @@
"next": "dist/bin/next"
},
"engines": {
"node": ">=18.18.0"
"node": "^18.18.0 || ^19.8.0 || >= 20.0.0"
},
"optionalDependencies": {
"@next/swc-darwin-arm64": "15.0.2",
"@next/swc-darwin-x64": "15.0.2",
"@next/swc-linux-arm64-gnu": "15.0.2",
"@next/swc-linux-arm64-musl": "15.0.2",
"@next/swc-linux-x64-gnu": "15.0.2",
"@next/swc-linux-x64-musl": "15.0.2",
"@next/swc-win32-arm64-msvc": "15.0.2",
"@next/swc-win32-x64-msvc": "15.0.2",
"@next/swc-darwin-arm64": "15.2.0",
"@next/swc-darwin-x64": "15.2.0",
"@next/swc-linux-arm64-gnu": "15.2.0",
"@next/swc-linux-arm64-musl": "15.2.0",
"@next/swc-linux-x64-gnu": "15.2.0",
"@next/swc-linux-x64-musl": "15.2.0",
"@next/swc-win32-arm64-msvc": "15.2.0",
"@next/swc-win32-x64-msvc": "15.2.0",
"sharp": "^0.33.5"
},
"peerDependencies": {
"@opentelemetry/api": "^1.1.0",
"@playwright/test": "^1.41.2",
"babel-plugin-react-compiler": "*",
"react": "^18.2.0 || 19.0.0-rc-02c0e824-20241028",
"react-dom": "^18.2.0 || 19.0.0-rc-02c0e824-20241028",
"react": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0",
"react-dom": "^18.2.0 || 19.0.0-rc-de68d2f4-20241204 || ^19.0.0",
"sass": "^1.3.0"
},
"peerDependenciesMeta": {
@ -20590,9 +20601,10 @@
}
},
"node_modules/tslib": {
"version": "2.6.2",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.6.2.tgz",
"integrity": "sha512-AEYxH93jGFPn/a2iVAwW87VuUIkR1FVUKB77NwMF7nBTDkDrrT/Hpt/IrCJ0QXhW27jTBDcf5ZY7w6RiqTMw2Q=="
"version": "2.8.1",
"resolved": "https://registry.npmjs.org/tslib/-/tslib-2.8.1.tgz",
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
"license": "0BSD"
},
"node_modules/type-check": {
"version": "0.4.0",


@ -55,7 +55,7 @@
"lodash": "^4.17.21",
"lucide-react": "^0.454.0",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^15.0.2",
"next": "^15.2.0",
"next-themes": "^0.4.4",
"npm": "^10.8.0",
"postcss": "^8.4.31",


@ -100,11 +100,6 @@ export default function LabelManagement() {
width="w-full max-w-xs"
name={`editLabelName_${label.id}`}
label="Label Name"
value={
values.editLabelId === label.id
? values.editLabelName
: label.name
}
onChange={(e) => {
setFieldValue("editLabelId", label.id);
setFieldValue("editLabelName", e.target.value);


@ -163,7 +163,7 @@ export function PersonasTable() {
{popup}
{deleteModalOpen && personaToDelete && (
<ConfirmEntityModal
entityType="Persona"
entityType="Assistant"
entityName={personaToDelete.name}
onClose={closeDeleteModal}
onSubmit={handleDeletePersona}


@ -52,7 +52,6 @@ export default function StarterMessagesList({
<TextFormField
name={`starter_messages.${index}.message`}
label=""
value={starterMessage.message}
onChange={(e) => handleInputChange(index, e.target.value)}
className="flex-grow"
removeLabel


@ -83,6 +83,8 @@ export const SlackChannelConfigCreationForm = ({
respond_tag_only:
existingSlackChannelConfig?.channel_config?.respond_tag_only ||
false,
is_ephemeral:
existingSlackChannelConfig?.channel_config?.is_ephemeral || false,
respond_to_bots:
existingSlackChannelConfig?.channel_config?.respond_to_bots ||
false,
@ -135,6 +137,7 @@ export const SlackChannelConfigCreationForm = ({
questionmark_prefilter_enabled: Yup.boolean().required(),
respond_tag_only: Yup.boolean().required(),
respond_to_bots: Yup.boolean().required(),
is_ephemeral: Yup.boolean().required(),
show_continue_in_web_ui: Yup.boolean().required(),
enable_auto_filters: Yup.boolean().required(),
respond_member_group_list: Yup.array().of(Yup.string()).required(),


@ -199,17 +199,17 @@ export function SlackChannelConfigFormFields({
<Badge variant="agent" className="bg-blue-100 text-blue-800">
Default Configuration
</Badge>
<p className="mt-2 text-sm text-gray-600">
<p className="mt-2 text-sm text-neutral-600">
This default configuration will apply across all Slack channels
the bot is added to in the Slack workspace, as well as direct
messages (DMs), unless disabled.
</p>
<div className="mt-4 p-4 bg-gray-100 rounded-md border border-gray-300">
<div className="mt-4 p-4 bg-neutral-100 rounded-md border border-neutral-300">
<CheckFormField
name="disabled"
label="Disable Default Configuration"
/>
<p className="mt-2 text-sm text-gray-600 italic">
<p className="mt-2 text-sm text-neutral-600 italic">
Warning: Disabling the default configuration means the bot
won&apos;t respond in Slack channels or DMs unless explicitly
configured for them.
@ -238,20 +238,28 @@ export function SlackChannelConfigFormFields({
/>
</div>
) : (
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
/>
)}
</Field>
<>
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
/>
)}
</Field>
<p className="mt-2 text-sm dark:text-neutral-400 text-neutral-600">
Note: This list shows public and private channels where the
bot is a member (up to 500 channels). If you don&apos;t see a
channel, make sure the bot is added to that channel in Slack
first, or type the channel name manually.
</p>
</>
)}
</>
)}
@ -589,6 +597,13 @@ export function SlackChannelConfigFormFields({
label="Respond to Bot messages"
tooltip="If not set, OnyxBot will always ignore messages from Bots"
/>
<CheckFormField
name="is_ephemeral"
label="Respond to user in a private (ephemeral) message"
tooltip="If set, OnyxBot will respond only to the user in a private (ephemeral) message. If you also
chose 'Search' Assistant above, selecting this option will make documents that are private to the user
available for their queries."
/>
<TextArrayField
name="respond_member_group_list"
@ -627,11 +642,14 @@ export function SlackChannelConfigFormFields({
Privacy Alert
</Label>
<p className="text-sm text-text-darker mb-4">
Please note that at least one of the documents accessible by
your OnyxBot is marked as private and may contain sensitive
information. These documents will be accessible to all users
of this OnyxBot. Ensure this aligns with your intended
document sharing policy.
Please note that if the private (ephemeral) response is *not
selected*, only public documents within the selected document
sets will be accessible for user queries. If the private
(ephemeral) response *is selected*, user queries can also
leverage documents that the user has already been granted
access to. Note that users will be able to share the response
with others in the channel, so please ensure that this is
aligned with your company's sharing policies.
</p>
<div className="space-y-2">
<h4 className="text-sm text-text font-medium">


@ -14,6 +14,7 @@ interface SlackChannelConfigCreationRequest {
answer_validity_check_enabled: boolean;
questionmark_prefilter_enabled: boolean;
respond_tag_only: boolean;
is_ephemeral: boolean;
respond_to_bots: boolean;
show_continue_in_web_ui: boolean;
respond_member_group_list: string[];
@ -45,6 +46,7 @@ const buildRequestBodyFromCreationRequest = (
channel_name: creationRequest.channel_name,
respond_tag_only: creationRequest.respond_tag_only,
respond_to_bots: creationRequest.respond_to_bots,
is_ephemeral: creationRequest.is_ephemeral,
show_continue_in_web_ui: creationRequest.show_continue_in_web_ui,
enable_auto_filters: creationRequest.enable_auto_filters,
respond_member_group_list: creationRequest.respond_member_group_list,


@ -71,7 +71,7 @@ function Main() {
<p className="text-text-600">
Learn more about Unstructured{" "}
<a
href="https://unstructured.io/docs"
href="https://docs.unstructured.io/welcome"
target="_blank"
rel="noopener noreferrer"
className="text-blue-500 hover:underline font-medium"


@ -108,15 +108,13 @@ export default function UpgradingPage({
>
<div>
<div>
Are you sure you want to cancel?
<br />
<br />
Cancelling will revert to the previous model and all progress will
be lost.
Are you sure you want to cancel? Cancelling will revert to the
previous model and all progress will be lost.
</div>
<div className="flex">
<Button onClick={onCancel} variant="submit">
Confirm
<div className="mt-12 gap-x-2 w-full justify-end flex">
<Button onClick={onCancel}>Confirm</Button>
<Button onClick={() => setIsCancelling(false)} variant="outline">
Cancel
</Button>
</div>
</div>
@ -141,30 +139,46 @@ export default function UpgradingPage({
</Button>
{connectors && connectors.length > 0 ? (
<>
{failedIndexingStatus && failedIndexingStatus.length > 0 && (
<FailedReIndexAttempts
failedIndexingStatuses={failedIndexingStatus}
setPopup={setPopup}
/>
)}
futureEmbeddingModel.background_reindex_enabled ? (
<>
{failedIndexingStatus && failedIndexingStatus.length > 0 && (
<FailedReIndexAttempts
failedIndexingStatuses={failedIndexingStatus}
setPopup={setPopup}
/>
)}
<Text className="my-4">
The table below shows the re-indexing progress of all existing
connectors. Once all connectors have been re-indexed
successfully, the new model will be used for all search
queries. Until then, we will use the old model so that no
downtime is necessary during this transition.
</Text>
<Text className="my-4">
The table below shows the re-indexing progress of all
existing connectors. Once all connectors have been
re-indexed successfully, the new model will be used for all
search queries. Until then, we will use the old model so
that no downtime is necessary during this transition.
</Text>
{sortedReindexingProgress ? (
<ReindexingProgressTable
reindexingProgress={sortedReindexingProgress}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch reindexing progress" />
)}
</>
{sortedReindexingProgress ? (
<ReindexingProgressTable
reindexingProgress={sortedReindexingProgress}
/>
) : (
<ErrorCallout errorTitle="Failed to fetch re-indexing progress" />
)}
</>
) : (
<div className="mt-8">
<h3 className="text-lg font-semibold mb-2">
Switching Embedding Models
</h3>
<p className="mb-4 text-text-800">
You&apos;re currently switching embedding models, and
you&apos;ve selected the instant switch option. The
transition will complete shortly.
</p>
<p className="text-text-600">
The new model will be active soon.
</p>
</div>
)
) : (
<div className="mt-8 p-6 bg-background-100 border border-border-strong rounded-lg max-w-2xl">
<h3 className="text-lg font-semibold mb-2">


@ -455,15 +455,15 @@ function Main({ ccPairId }: { ccPairId: number }) {
<Title>Indexing Attempts</Title>
</div>
{indexAttemptErrors && indexAttemptErrors.total_items > 0 && (
<Alert className="border-alert bg-yellow-50 my-2">
<AlertCircle className="h-4 w-4 text-yellow-700" />
<AlertTitle className="text-yellow-950 font-semibold">
<Alert className="border-alert bg-yellow-50 dark:bg-yellow-800 my-2">
<AlertCircle className="h-4 w-4 text-yellow-700 dark:text-yellow-500" />
<AlertTitle className="text-yellow-950 dark:text-yellow-200 font-semibold">
Some documents failed to index
</AlertTitle>
<AlertDescription className="text-yellow-900">
<AlertDescription className="text-yellow-900 dark:text-yellow-300">
{isResolvingErrors ? (
<span>
<span className="text-sm text-yellow-700 animate-pulse">
<span className="text-sm text-yellow-700 dark:text-yellow-400 da animate-pulse">
Resolving failures
</span>
</span>
@ -471,7 +471,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
<>
We ran into some issues while processing some documents.{" "}
<b
className="text-link cursor-pointer"
className="text-link cursor-pointer dark:text-blue-300"
onClick={() => setShowIndexAttemptErrors(true)}
>
View details.


@ -193,13 +193,15 @@ export default function AddConnector({
// Check if there are no credentials
const noCredentials = credentialTemplate == null;
if (noCredentials && 1 != formStep) {
setFormStep(Math.max(1, formStep));
}
useEffect(() => {
if (noCredentials && 1 != formStep) {
setFormStep(Math.max(1, formStep));
}
if (!noCredentials && !credentialActivated && formStep != 0) {
setFormStep(Math.min(formStep, 0));
}
if (!noCredentials && !credentialActivated && formStep != 0) {
setFormStep(Math.min(formStep, 0));
}
}, [noCredentials, formStep, setFormStep]);
const convertStringToDateTime = (indexingStart: string | null) => {
return indexingStart ? new Date(indexingStart) : null;


@ -33,7 +33,7 @@ export default function OAuthCallbackPage() {
const connector = pathname?.split("/")[3];
useEffect(() => {
const handleOAuthCallback = async () => {
const onFirstLoad = async () => {
// Examples
// connector (url segment)= "google-drive"
// sourceType (for looking up metadata) = "google_drive"
@ -85,10 +85,19 @@ export default function OAuthCallbackPage() {
}
setStatusMessage("Success!");
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
);
setRedirectUrl(response.redirect_on_success); // Extract the redirect URL
// set the continuation link
if (response.finalize_url) {
setRedirectUrl(response.finalize_url);
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully. Additional steps are required to complete credential setup.`
);
} else {
setRedirectUrl(response.redirect_on_success);
setStatusDetails(
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
);
}
setIsError(false);
} catch (error) {
console.error("OAuth error:", error);
@ -100,15 +109,15 @@ export default function OAuthCallbackPage() {
}
};
handleOAuthCallback();
onFirstLoad();
}, [code, state, connector]);
return (
<div className="container mx-auto py-8">
<div className="mx-auto h-screen flex flex-col">
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
<div className="flex flex-col items-center justify-center min-h-screen">
<CardSection className="max-w-md">
<div className="flex-1 flex flex-col items-center justify-center">
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
<p className="text-text-500">{statusDetails}</p>
{redirectUrl && !isError && (


@ -0,0 +1,293 @@
"use client";
import { useEffect, useState } from "react";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { AdminPageTitle } from "@/components/admin/Title";
import { Button } from "@/components/ui/button";
import Title from "@/components/ui/title";
import { KeyIcon } from "@/components/icons/icons";
import { getSourceMetadata, isValidSource } from "@/lib/sources";
import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types";
import CardSection from "@/components/admin/CardSection";
import {
handleOAuthAuthorizationResponse,
handleOAuthConfluenceFinalize,
handleOAuthPrepareFinalization,
} from "@/lib/oauth_utils";
import { SelectorFormField } from "@/components/admin/connectors/Field";
import { ErrorMessage, Field, Form, Formik, useFormikContext } from "formik";
import * as Yup from "yup";
// Helper component to keep the effect logic clean:
function UpdateCloudURLOnCloudIdChange({
accessibleResources,
}: {
accessibleResources: ConfluenceAccessibleResource[];
}) {
const { values, setValues, setFieldValue } = useFormikContext<{
cloud_id: string;
cloud_name: string;
cloud_url: string;
}>();
useEffect(() => {
// Whenever cloud_id changes, find the matching resource and update cloud_url
if (values.cloud_id) {
const selectedResource = accessibleResources.find(
(resource) => resource.id === values.cloud_id
);
if (selectedResource) {
// Update multiple fields together; setting them in sequence doesn't
// work with the validator. It may also be possible to await each
// setFieldValue call instead.
// https://github.com/jaredpalmer/formik/issues/2266
setValues((prevValues) => ({
...prevValues,
cloud_name: selectedResource.name,
cloud_url: selectedResource.url,
}));
}
}
}, [values.cloud_id, accessibleResources, setFieldValue]);
// This component doesn't render anything visible:
return null;
}
export default function OAuthFinalizePage() {
const router = useRouter();
const searchParams = useSearchParams();
const [statusMessage, setStatusMessage] = useState("Processing...");
const [statusDetails, setStatusDetails] = useState(
"Please wait while we complete the setup."
);
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
const [isError, setIsError] = useState(false);
const [isSubmitted, setIsSubmitted] = useState(false); // New state
const [pageTitle, setPageTitle] = useState(
"Finalize Authorization with Third-Party service"
);
const [accessibleResources, setAccessibleResources] = useState<
ConfluenceAccessibleResource[]
>([]);
// Extract query parameters
const credentialParam = searchParams.get("credential");
const credential = credentialParam ? parseInt(credentialParam, 10) : NaN;
const pathname = usePathname();
const connector = pathname?.split("/")[3];
useEffect(() => {
const onFirstLoad = async () => {
// Examples
// connector (url segment)= "google-drive"
// sourceType (for looking up metadata) = "google_drive"
if (isNaN(credential)) {
setStatusMessage("Improperly formed OAuth finalization request.");
setStatusDetails("Invalid or missing credential id.");
setIsError(true);
return;
}
const sourceType = connector.replaceAll("-", "_");
if (!isValidSource(sourceType)) {
setStatusMessage(
`The specified connector source type ${sourceType} does not exist.`
);
setStatusDetails(`${sourceType} is not a valid source type.`);
setIsError(true);
return;
}
const sourceMetadata = getSourceMetadata(sourceType as ValidSources);
setPageTitle(`Finalize Authorization with ${sourceMetadata.displayName}`);
setStatusMessage("Processing...");
setStatusDetails(
"Please wait while we retrieve a list of your accessible sites."
);
setIsError(false); // Ensure no error state during loading
try {
const response = await handleOAuthPrepareFinalization(
connector,
credential
);
if (!response) {
throw new Error("Empty response from OAuth server.");
}
setAccessibleResources(response.accessible_resources);
setStatusMessage("Select a Confluence site");
setStatusDetails("");
setIsError(false);
} catch (error) {
console.error("OAuth finalization error:", error);
setStatusMessage("Oops, something went wrong!");
setStatusDetails(
"An error occurred during the OAuth finalization process. Please try again."
);
setIsError(true);
}
};
onFirstLoad();
}, [credential, connector]);
useEffect(() => {}, [redirectUrl]);
return (
<div className="mx-auto h-screen flex flex-col">
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
<div className="flex-1 flex flex-col items-center justify-center">
<CardSection className="max-w-md w-[500px] h-[250px] p-8">
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
<p className="text-text-500">{statusDetails}</p>
<Formik
initialValues={{
credential_id: credential,
cloud_id: "",
cloud_name: "",
cloud_url: "",
}}
validationSchema={Yup.object().shape({
credential_id: Yup.number().required(
"Credential ID is required."
),
cloud_id: Yup.string().required(
"You must select a Confluence site (id not found)."
),
cloud_name: Yup.string().required(
"You must select a Confluence site (name not found)."
),
cloud_url: Yup.string().required(
"You must select a Confluence site (url not found)."
),
})}
validateOnMount
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
try {
if (!values.cloud_id) {
throw new Error("Cloud ID is required.");
}
if (!values.cloud_name) {
throw new Error("Cloud URL is required.");
}
if (!values.cloud_url) {
throw new Error("Cloud URL is required.");
}
const response = await handleOAuthConfluenceFinalize(
values.credential_id,
values.cloud_id,
values.cloud_name,
values.cloud_url
);
formikHelpers.setSubmitting(false);
if (response) {
setRedirectUrl(response.redirect_url);
setStatusMessage("Confluence authorization finalized.");
}
setIsSubmitted(true); // Mark as submitted
} catch (error) {
console.error(error);
setStatusMessage("Error during submission.");
setStatusDetails(
"An error occurred during the submission process. Please try again."
);
setIsError(true);
formikHelpers.setSubmitting(false);
}
}}
>
{({ isSubmitting, isValid, setFieldValue }) => (
<Form>
{/* Debug info
<div className="mb-4 p-2 bg-gray-100 rounded text-xs">
<pre>
isValid: {String(isValid)}
errors: {JSON.stringify(errors, null, 2)}
values: {JSON.stringify(values, null, 2)}
</pre>
</div> */}
{/* Our helper component that reacts to changes in cloud_id */}
<UpdateCloudURLOnCloudIdChange
accessibleResources={accessibleResources}
/>
<Field type="hidden" name="cloud_name" />
<ErrorMessage
name="cloud_name"
component="div"
className="error"
/>
<Field type="hidden" name="cloud_url" />
<ErrorMessage
name="cloud_url"
component="div"
className="error"
/>
{!redirectUrl && accessibleResources.length > 0 && (
<SelectorFormField
name="cloud_id"
options={accessibleResources.map((resource) => ({
name: `${resource.name} - ${resource.url}`,
value: resource.id,
}))}
onSelect={(selectedValue) => {
const selectedResource = accessibleResources.find(
(resource) => resource.id === selectedValue
);
if (selectedResource) {
setFieldValue("cloud_id", selectedResource.id);
}
}}
/>
)}
<br />
{!redirectUrl && (
<Button
type="submit"
size="sm"
variant="submit"
disabled={!isValid || isSubmitting}
>
{isSubmitting ? "Submitting..." : "Submit"}
</Button>
)}
</Form>
)}
</Formik>
{redirectUrl && !isError && (
<div className="mt-4">
<p className="text-sm">
Authorization finalized. Click{" "}
<a href={redirectUrl} className="text-blue-500 underline">
here
</a>{" "}
to continue.
</p>
</div>
)}
</CardSection>
</div>
</div>
);
}

Some files were not shown because too many files have changed in this diff.