Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/salesforce

# Conflicts:
#	backend/onyx/connectors/models.py
This commit is contained in:
Richard Kuo (Onyx) 2025-03-21 16:44:41 -07:00
commit 8cb4a81242
80 changed files with 3212 additions and 894 deletions

View File

@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
from ee.onyx.server.enterprise_settings.store import get_logo_filename
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
from ee.onyx.server.enterprise_settings.store import load_analytics_script
from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.utils.logger import setup_logger
admin_router = APIRouter(prefix="/admin/enterprise-settings")
@@ -131,31 +131,49 @@ def put_logo(
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
file_io = file_store.read_file(filename, mode="b")
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
detail="No logo file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail="No logotype file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
@basic_router.get("/logotype")
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
return fetch_logotype_helper(db_session)
@basic_router.get("/logo")
def fetch_logo(
is_logotype: bool = False, db_session: Session = Depends(get_session)
) -> Response:
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
if is_logotype:
return fetch_logotype_helper(db_session)
return fetch_logo_helper(db_session)
@admin_router.put("/custom-analytics-script")

View File

@@ -13,6 +13,7 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -21,8 +22,18 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def load_settings() -> EnterpriseSettings:
"""Loads settings data directly from DB. This should be used primarily
for checking what is actually in the DB, aka for editing and saving back settings.
Runtime settings actually used by the application should be checked with
load_runtime_settings as defaults may be applied at runtime.
"""
dynamic_config_store = get_kv_store()
try:
settings = EnterpriseSettings(
@@ -36,9 +47,24 @@ def load_settings() -> EnterpriseSettings:
def store_settings(settings: EnterpriseSettings) -> None:
"""Stores settings directly to the kv store / db."""
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
def load_runtime_settings() -> EnterpriseSettings:
"""Loads settings from DB and applies any defaults or transformations for use
at runtime.
Should not be stored back to the DB.
"""
enterprise_settings = load_settings()
if not enterprise_settings.application_name:
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
return enterprise_settings
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
@@ -60,10 +86,6 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def is_valid_file_type(filename: str) -> bool:
valid_extensions = (".png", ".jpg", ".jpeg")
return filename.endswith(valid_extensions)
@@ -116,3 +138,11 @@ def upload_logo(
file_type=file_type,
)
return True
def get_logo_filename() -> str:
return _LOGO_FILENAME
def get_logotype_filename() -> str:
return _LOGOTYPE_FILENAME

View File

@ -271,6 +271,7 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@ -283,7 +284,7 @@ def configure_default_api_keys(db_session: Session) -> None:
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
@ -291,9 +292,10 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")

View File

@ -65,11 +65,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
app.state.gpu_type = gpu_type
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
try:
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
except Exception as e:
logger.warning(
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
"This is not a critical error and the model server will continue to run."
)
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")

View File

@ -1,5 +1,6 @@
import smtplib
from datetime import datetime
from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
@ -13,8 +14,13 @@ from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
@ -97,8 +103,8 @@ HTML_EMAIL_TEMPLATE = """\
<td class="header">
<img
style="background-color: #ffffff; border-radius: 8px;"
src="https://www.onyx.app/logos/customer/onyx.png"
alt="Onyx Logo"
src="cid:logo.png"
alt="{application_name} Logo"
>
</td>
</tr>
@ -113,9 +119,8 @@ HTML_EMAIL_TEMPLATE = """\
</tr>
<tr>
<td class="footer">
© {year} Onyx. All rights reserved.
<br>
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
© {year} {application_name}. All rights reserved.
{slack_fragment}
</td>
</tr>
</table>
@ -125,17 +130,27 @@ HTML_EMAIL_TEMPLATE = """\
def build_html_email(
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
application_name: str | None,
heading: str,
message: str,
cta_text: str | None = None,
cta_link: str | None = None,
) -> str:
slack_fragment = ""
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
else:
cta_block = ""
return HTML_EMAIL_TEMPLATE.format(
application_name=application_name,
title=heading,
heading=heading,
message=message,
cta_block=cta_block,
slack_fragment=slack_fragment,
year=datetime.now().year,
)
@ -146,6 +161,7 @@ def send_email(
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
@ -164,6 +180,12 @@ def send_email(
msg.attach(part_text)
msg.attach(part_html)
if inline_png:
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", inline_png[0]) # CID reference
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
msg.attach(img)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
@ -174,8 +196,21 @@ def send_email(
def send_subscription_cancellation_email(user_email: str) -> None:
"""This is templated but isn't meaningful for whitelabeling."""
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Your {application_name} Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
@ -184,23 +219,48 @@ def send_subscription_cancellation_email(user_email: str) -> None:
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)
send_email(
user_email,
subject,
html_content,
text_content,
inline_png=("logo.png", onyx_file.data),
)
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
subject = "Invitation to Join Onyx Organization"
onyx_file: FileWithMimeType | None = None
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Invitation to Join {application_name} Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
@ -226,19 +286,32 @@ def send_user_email_invite(
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)
send_email(
user_email,
subject,
html_content,
text_content,
inline_png=("logo.png", onyx_file.data),
)
def send_forgot_password_email(
@ -248,14 +321,36 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
html_content = build_html_email(
application_name,
"Reset Your Password",
message,
)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
send_email(
user_email,
subject,
html_content,
text_content,
mail_from,
inline_png=("logo.png", onyx_file.data),
)
def send_user_verification_email(
@ -264,11 +359,33 @@ def send_user_verification_email(
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
html_content = build_html_email(
application_name,
"Verify Your Email",
message,
)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
send_email(
user_email,
subject,
html_content,
text_content,
mail_from,
inline_png=("logo.png", onyx_file.data),
)

View File

@ -46,7 +46,6 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@ -420,12 +419,7 @@ def connector_permission_sync_generator_task(
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source

View File

@ -41,7 +41,6 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@ -402,12 +401,7 @@ def connector_external_group_sync_generator_task(
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source
@ -425,12 +419,9 @@ def connector_external_group_sync_generator_task(
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e

View File

@ -6,6 +6,8 @@ from sqlalchemy import and_
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
@ -16,7 +18,6 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof
logger = setup_logger()
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
@ -52,7 +53,7 @@ def save_checkpoint(
def load_checkpoint(
db_session: Session, index_attempt_id: int
db_session: Session, index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint | None:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
@ -60,6 +61,8 @@ def load_checkpoint(
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointConnector):
return connector.validate_checkpoint_json(checkpoint_data)
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:
return None
@ -71,6 +74,7 @@ def get_latest_valid_checkpoint(
search_settings_id: int,
window_start: datetime,
window_end: datetime,
connector: BaseConnector,
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
@ -105,7 +109,7 @@ def get_latest_valid_checkpoint(
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
return connector.build_dummy_checkpoint()
# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
@ -113,12 +117,13 @@ def get_latest_valid_checkpoint(
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector.build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
except Exception:
logger.exception(
@ -193,7 +198,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
content_size = deep_getsizeof(checkpoint.model_dump())
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"

View File

@ -24,7 +24,6 @@ from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@ -32,8 +31,11 @@ from onyx.connectors.models import TextSection
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
@ -46,8 +48,6 @@ from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
@ -387,6 +387,7 @@ def _run_indexing(
net_doc_change = 0
document_count = 0
chunk_count = 0
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
@ -405,7 +406,7 @@ def _run_indexing(
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
@ -413,6 +414,7 @@ def _run_indexing(
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
unresolved_errors = get_index_attempt_errors_for_cc_pair(
@ -596,16 +598,44 @@ def _run_indexing(
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if ctx.is_primary:
update_connector_credential_pair(
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e

View File

@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.gpu_utils import fast_gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@ -88,7 +88,9 @@ class Answer:
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
allow_agent_reranking = (
fast_gpu_status_request(indexing=False) or using_cloud_reranking
)
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.

View File

@ -33,6 +33,10 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
# Controls whether to allow admin query history reports with:
# 1. associated user emails
# 2. anonymized user emails
# 3. no queries
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)
@ -158,6 +162,8 @@ try:
except ValueError:
INDEX_BATCH_SIZE = 16
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
@ -341,8 +347,8 @@ HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
HtmlBasedConnectorTransformLinksStrategy.STRIP,
)
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
== "true"
)
@ -414,6 +420,9 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)

View File

@ -3,6 +3,10 @@ import socket
from enum import auto
from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@ -40,6 +44,7 @@ DISABLED_GEN_AI_MSG = (
"You can still use Onyx as a search engine."
)
DEFAULT_PERSONA_ID = 0
DEFAULT_CC_PAIR_ID = 1

View File

@ -114,6 +114,7 @@ class ConfluenceConnector(
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
@ -158,6 +159,9 @@ class ConfluenceConnector(
"max_backoff_seconds": 60,
}
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@ -233,7 +237,9 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
# Get the page content
page_content = extract_text_from_confluence_html(
@ -264,6 +270,7 @@ class ConfluenceConnector(
self.confluence_client,
attachment,
page_id,
self.allow_images,
)
if result and result.text:
@ -304,13 +311,14 @@ class ConfluenceConnector(
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
primary_owners.append(BasicExpertInfo(display_name=display_name))
email = author.get("email", "unknown@domain.invalid")
primary_owners.append(
BasicExpertInfo(display_name=display_name, email=email)
)
# Create the document
return Document(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
id=page_url,
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@ -373,6 +381,7 @@ class ConfluenceConnector(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue

View File

@ -112,6 +112,7 @@ def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
allow_images: bool,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
@ -119,7 +120,7 @@ def process_attachment(
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
@ -138,7 +139,14 @@ def process_attachment(
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if media_type.startswith("image/"):
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_name=None,
error="Image downloading is not enabled",
)
else:
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
@ -294,6 +302,7 @@ def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
@ -309,7 +318,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id)
result = process_attachment(confluence_client, attachment, page_id, allow_images)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@ -2,6 +2,8 @@ import sys
import time
from collections.abc import Generator
from datetime import datetime
from typing import Generic
from typing import TypeVar
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
@ -19,8 +21,10 @@ logger = setup_logger()
TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class CheckpointOutputWrapper:
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
@ -29,20 +33,20 @@ class CheckpointOutputWrapper:
"""
def __init__(self) -> None:
self.next_checkpoint: ConnectorCheckpoint | None = None
self.next_checkpoint: CT | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
tuple[Document | None, ConnectorFailure | None, CT | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput,
) -> CheckpointOutput:
checkpoint_connector_generator: CheckpointOutput[CT],
) -> CheckpointOutput[CT]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
@ -64,7 +68,7 @@ class CheckpointOutputWrapper:
yield None, None, self.next_checkpoint
class ConnectorRunner:
class ConnectorRunner(Generic[CT]):
"""
Handles:
- Batching
@ -85,11 +89,9 @@ class ConnectorRunner:
self.doc_batch: list[Document] = []
def run(
self, checkpoint: ConnectorCheckpoint
self, checkpoint: CT
) -> Generator[
tuple[
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
],
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
None,
None,
]:
@ -105,9 +107,9 @@ class ConnectorRunner:
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: ConnectorCheckpoint | None = None
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
@ -132,7 +134,7 @@ class ConnectorRunner:
)
else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint = self.connector.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):

View File

@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@ -184,6 +185,8 @@ def instantiate_connector(
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
return connector

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,5 @@
import io
from collections.abc import Callable
from datetime import datetime
from typing import cast
@ -13,7 +14,9 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
@ -76,6 +79,7 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
@ -84,6 +88,10 @@ def _extract_sections_basic(
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@ -202,12 +210,16 @@ def _extract_sections_basic(
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
"""
doc_id = ""
sections: list[TextSection | ImageSection] = []
try:
# skip shortcuts or folders
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
@ -215,13 +227,11 @@ def convert_drive_item_to_document(
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
# get_document_sections is the advanced approach for Google Docs
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
docs_service=docs_service(), doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
@ -232,7 +242,7 @@ def convert_drive_item_to_document(
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service)
sections = _extract_sections_basic(file, drive_service(), allow_images)
# If we still don't have any sections, skip this file
if not sections:
@ -257,8 +267,19 @@ def convert_drive_item_to_document(
),
)
except Exception as e:
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
error_str = f"Error converting file '{file.get('name')}' to Document: {e}"
logger.exception(error_str)
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=sections[0].link
if sections
else None, # TODO: see if this is the best way to get a link
),
failed_entity=None,
failure_message=error_str,
exception=e,
)
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@ -1,17 +1,23 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.models import RetrievedDriveFile
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import GoogleFields
from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
logger = setup_logger()
FILE_FIELDS = (
@ -31,11 +37,13 @@ def _generate_time_range_filter(
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
time_range_filter += (
f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
)
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
return time_range_filter
@ -66,9 +74,9 @@ def _get_folders_in_parent(
def _get_files_in_parent(
service: Resource,
parent_id: str,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
@ -83,6 +91,7 @@ def _get_files_in_parent(
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
**({} if is_slim else {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}),
):
yield file
@ -90,30 +99,50 @@ def _get_files_in_parent(
def crawl_folders_for_files(
service: Resource,
parent_id: str,
is_slim: bool,
user_email: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
) -> Iterator[RetrievedDriveFile]:
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
return
found_files = False
for file in _get_files_in_parent(
service=service,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
if found_files:
update_traversed_ids_func(parent_id)
logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
if parent_id not in traversed_parent_ids:
logger.info("Parent id not in traversed parent ids, getting files")
found_files = False
file = {}
try:
for file in _get_files_in_parent(
service=service,
parent_id=parent_id,
is_slim=is_slim,
start=start,
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
)
except Exception as e:
logger.error(f"Error getting files in parent {parent_id}: {e}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
error=e,
)
if found_files:
update_traversed_ids_func(parent_id)
else:
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")
for subfolder in _get_folders_in_parent(
service=service,
@ -123,6 +152,8 @@ def crawl_folders_for_files(
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
is_slim=is_slim,
user_email=user_email,
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
@ -133,16 +164,19 @@ def crawl_folders_for_files(
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
is_slim: bool,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs = {}
if not is_slim:
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
@ -155,15 +189,13 @@ def get_files_in_shared_drive(
q=folder_query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
@ -173,16 +205,26 @@ def get_files_in_shared_drive(
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
)
**kwargs,
):
# If we found any files, mark this drive as traversed. When a user has access to a drive,
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
# empty drives.
update_traversed_ids_func(drive_id)
yield file
def get_all_files_in_my_drive(
service: Any,
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool = False,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs = {}
if not is_slim:
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
@ -196,7 +238,7 @@ def get_all_files_in_my_drive(
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=folder_query,
):
update_traversed_ids_func(file["id"])
update_traversed_ids_func(file[GoogleFields.ID])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
@ -209,22 +251,28 @@ def get_all_files_in_my_drive(
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
def get_all_files_for_oauth(
service: Any,
service: GoogleDriveService,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
is_slim: bool = False,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs = {}
if not is_slim:
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
should_get_all = (
include_shared_drives and include_my_drives and include_files_shared_with_me
)
@ -243,11 +291,13 @@ def get_all_files_for_oauth(
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
@ -255,4 +305,8 @@ def get_all_files_for_oauth(
def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]
return (
service.files()
.get(fileId="root", fields=GoogleFields.ID.value)
.execute()[GoogleFields.ID.value]
)

View File

@ -1,6 +1,15 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_serializer
from pydantic import field_validator
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.threadpool_concurrency import ThreadSafeDict
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
@ -20,3 +29,128 @@ class GDriveMimeType(str, Enum):
GoogleDriveFileType = dict[str, Any]
TOKEN_EXPIRATION_TIME = 3600 # 1 hour
# These correspond to The major stages of retrieval for google drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
START = "start"
DONE = "done"
# OAuth specific stages
OAUTH_FILES = "oauth_files"
# Service account specific stages
USER_EMAILS = "user_emails"
MY_DRIVE_FILES = "my_drive_files"
# Used for both oauth and service account flows
DRIVE_IDS = "drive_ids"
SHARED_DRIVE_FILES = "shared_drive_files"
FOLDER_FILES = "folder_files"
class StageCompletion(BaseModel):
"""
Describes the point in the retrieval+indexing process that the
connector is at. completed_until is the timestamp of the latest
file that has been retrieved or error that has been yielded.
Optional fields are used for retrieval stages that need more information
for resuming than just the timestamp of the latest file.
"""
stage: DriveRetrievalStage
completed_until: SecondsSinceUnixEpoch
completed_until_parent_id: str | None = None
# only used for shared drives
processed_drive_ids: set[str] = set()
def update(
self,
stage: DriveRetrievalStage,
completed_until: SecondsSinceUnixEpoch,
completed_until_parent_id: str | None = None,
) -> None:
self.stage = stage
self.completed_until = completed_until
self.completed_until_parent_id = completed_until_parent_id
class RetrievedDriveFile(BaseModel):
"""
Describes a file that has been retrieved from google drive.
user_email is the email of the user that the file was retrieved
by impersonating. If an error worthy of being reported is encountered,
error should be set and later propagated as a ConnectorFailure.
"""
# The stage at which this file was retrieved
completion_stage: DriveRetrievalStage
# The file that was retrieved
drive_file: GoogleDriveFileType
# The email of the user that the file was retrieved by impersonating
user_email: str
# The id of the parent folder or drive of the file
parent_id: str | None = None
# Any unexpected error that occurred while retrieving the file.
# In particular, this is not used for 403/404 errors, which are expected
# in the context of impersonating all the users to try to retrieve all
# files from all their Drives and Folders.
error: Exception | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class GoogleDriveCheckpoint(ConnectorCheckpoint):
# Checkpoint version of _retrieved_ids
retrieved_folder_and_drive_ids: set[str]
# Describes the point in the retrieval+indexing process that the
# checkpoint is at. when this is set to a given stage, the connector
# has finished yielding all values from the previous stage.
completion_stage: DriveRetrievalStage
# The latest timestamp of a file that has been retrieved per user email.
# StageCompletion is used to track the completion of each stage, but the
# timestamp part is not used for folder crawling.
completion_map: ThreadSafeDict[str, StageCompletion]
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None
# cached user emails
user_emails: list[str] | None = None
@field_serializer("completion_map")
def serialize_completion_map(
self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any
) -> dict[str, StageCompletion]:
return completion_map._dict
@field_validator("completion_map", mode="before")
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict(
{k: StageCompletion.model_validate(v) for k, v in v.items()}
)

View File

@ -4,6 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
@ -16,20 +17,37 @@ logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
# extended period of time. This is now addressed by checkpointing.
#
# NOTE: We previously tried to combat this here by adding a very
# long retry period (~20 minutes of trying, one request a minute.)
# This is no longer necessary due to checkpointing.
add_retries = retry_builder(tries=5, max_delay=10)
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
ORDER_BY_KEY = "orderBy"
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
class GoogleFields(str, Enum):
ID = "id"
CREATED_TIME = "createdTime"
MODIFIED_TIME = "modifiedTime"
NAME = "name"
SIZE = "size"
PARENTS = "parents"
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
max_attempts = 6
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 1. Other things are also requesting from the Drive/Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
@ -90,11 +108,11 @@ def execute_paginated_retrieval(
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
try:
results = retrieval_function(**request_kwargs).execute()
@ -117,7 +135,7 @@ def execute_paginated_retrieval(
logger.exception("Error executing request:")
raise e
next_page_token = results.get("nextPageToken")
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
if list_key:
for item in results.get(list_key, []):
yield item

View File

@ -4,9 +4,11 @@ from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeAlias
from typing import TypeVar
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
@ -19,10 +21,11 @@ SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class BaseConnector(abc.ABC):
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@ -57,6 +60,14 @@ class BaseConnector(abc.ABC):
Default is a no-op (always successful).
"""
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore
# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
@ -74,6 +85,8 @@ class PollConnector(BaseConnector):
raise NotImplementedError
# Slim connectors can retrieve just the ids and
# permission syncing information for connected documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
@ -186,14 +199,17 @@ class EventConnector(BaseConnector):
raise NotImplementedError
class CheckpointConnector(BaseConnector):
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
class CheckpointConnector(BaseConnector[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: CT,
) -> CheckpointOutput[CT]:
"""Yields back documents or failures. Final return is the new checkpoint.
Final return can be access via either:
@ -214,3 +230,12 @@ class CheckpointConnector(BaseConnector):
```
"""
raise NotImplementedError
@override
def build_dummy_checkpoint(self) -> CT:
raise NotImplementedError
@abc.abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
"""Validate the checkpoint json and return the checkpoint object"""
raise NotImplementedError

View File

@ -2,6 +2,7 @@ from typing import Any
import httpx
from pydantic import BaseModel
from typing_extensions import override
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
@ -15,14 +16,18 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class MockConnectorCheckpoint(ConnectorCheckpoint):
last_document_id: str | None = None
class SingleConnectorYield(BaseModel):
documents: list[Document]
checkpoint: ConnectorCheckpoint
checkpoint: MockConnectorCheckpoint
failures: list[ConnectorFailure]
unhandled_exception: str | None = None
class MockConnector(CheckpointConnector):
class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
def __init__(
self,
mock_server_host: str,
@ -48,7 +53,7 @@ class MockConnector(CheckpointConnector):
def _get_mock_server_url(self, endpoint: str) -> str:
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
def _save_checkpoint(self, checkpoint: MockConnectorCheckpoint) -> None:
response = self.client.post(
self._get_mock_server_url("add-checkpoint"),
json=checkpoint.model_dump(mode="json"),
@ -59,8 +64,8 @@ class MockConnector(CheckpointConnector):
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: MockConnectorCheckpoint,
) -> CheckpointOutput[MockConnectorCheckpoint]:
if self.connector_yields is None:
raise ValueError("No connector yields configured")
@ -84,3 +89,13 @@ class MockConnector(CheckpointConnector):
yield failure
return current_yield.checkpoint
@override
def build_dummy_checkpoint(self) -> MockConnectorCheckpoint:
return MockConnectorCheckpoint(
has_more=True,
last_document_id=None,
)
def validate_checkpoint_json(self, checkpoint_json: str) -> MockConnectorCheckpoint:
return MockConnectorCheckpoint.model_validate_json(checkpoint_json)

View File

@ -1,4 +1,3 @@
import json
import sys
from datetime import datetime
from enum import Enum
@ -279,21 +278,16 @@ class IndexAttemptMetadata(BaseModel):
class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool
@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
content_str = self.model_dump_json()
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
return content_str
class DocumentFailure(BaseModel):

View File

@ -1,16 +1,16 @@
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Optional
import requests
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
@ -25,6 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@ -38,8 +39,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
@dataclass
class NotionPage:
class NotionPage(BaseModel):
"""Represents a Notion Page object"""
id: str
@ -49,17 +49,10 @@ class NotionPage:
properties: dict[str, Any]
url: str
database_name: str | None # Only applicable to the database type page (wiki)
def __init__(self, **kwargs: dict[str, Any]) -> None:
names = set([f.name for f in fields(self)])
for k, v in kwargs.items():
if k in names:
setattr(self, k, v)
database_name: str | None = None # Only applicable to the database type page (wiki)
@dataclass
class NotionBlock:
class NotionBlock(BaseModel):
"""Represents a Notion Block object"""
id: str # Used for the URL
@ -69,20 +62,13 @@ class NotionBlock:
prefix: str
@dataclass
class NotionSearchResponse:
class NotionSearchResponse(BaseModel):
"""Represents the response from the Notion Search API"""
results: list[dict[str, Any]]
next_cursor: Optional[str]
has_more: bool = False
def __init__(self, **kwargs: dict[str, Any]) -> None:
names = set([f.name for f in fields(self)])
for k, v in kwargs.items():
if k in names:
setattr(self, k, v)
class NotionConnector(LoadConnector, PollConnector):
"""Notion Page connector that reads all Notion pages
@ -95,7 +81,7 @@ class NotionConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP,
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
root_page_id: str | None = None,
) -> None:
"""Initialize with parameters."""
@ -464,23 +450,53 @@ class NotionConnector(LoadConnector, PollConnector):
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
if not page_blocks:
continue
# okay to mark here since there's no way for this to not succeed
# without a critical failure
self.indexed_pages.add(page.id)
page_title = (
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
)
raw_page_title = self._read_page_title(page)
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
if not page_blocks:
if not raw_page_title:
logger.warning(
f"No blocks OR title found for page with ID '{page.id}'. Skipping."
)
continue
logger.debug(f"No blocks found for page with ID '{page.id}'")
"""
Something like:
TITLE
PROP1: PROP1_VALUE
PROP2: PROP2_VALUE
"""
text = page_title
if page.properties:
text += "\n\n" + "\n".join(
[f"{key}: {value}" for key, value in page.properties.items()]
)
sections = [
TextSection(
link=f"{page.url}",
text=text,
)
]
else:
sections = [
TextSection(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
for block in page_blocks
]
yield (
Document(
id=page.id,
sections=[
TextSection(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
for block in page_blocks
],
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.NOTION,
semantic_identifier=page_title,
doc_updated_at=datetime.fromisoformat(

View File

@ -6,6 +6,7 @@ from typing import Any
from jira import JIRA
from jira.resources import Issue
from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
@ -15,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
@ -42,121 +45,112 @@ _JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def _paginate_jql_search(
def _perform_jql_search(
jira_client: JIRA,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
if len(issues) < max_results:
break
start += max_results
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise RuntimeError(f"Found Jira object not of type Issue: {issue}")
def fetch_jira_issues_batch(
def process_jira_issue(
jira_client: JIRA,
jql: str,
batch_size: int,
issue: Issue,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
) -> Document | None:
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
return None
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
description = (
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["label"] = labels
yield Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
return None
page_url = build_jira_url(jira_client, issue.key)
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["labels"] = labels
return Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
metadata=metadata_dict,
)
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
class JiraConnectorCheckpoint(ConnectorCheckpoint):
    """Checkpoint state for the Jira connector.

    Persists how far pagination through the JQL search results has
    progressed so an interrupted indexing run can resume where it left off.
    """

    # Offset into the JQL result set already processed; None means the
    # run has not started yet (treated as 0 by the connector).
    offset: int | None = None
class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector):
def __init__(
self,
jira_base_url: str,
@ -200,33 +194,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
"""Get the JQL query based on whether a specific project is set"""
if self.jira_project:
return f"project = {self.quoted_jira_project}"
return "" # Empty string means all accessible projects
def load_from_state(self) -> GenerateDocumentsOutput:
jql = self._get_jql_query()
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def poll_source(
def _get_jql_query(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
) -> str:
"""Get the JQL query based on whether a specific project is set and time range"""
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@ -234,25 +205,61 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
document_batch = []
for doc in fetch_jira_issues_batch(
if self.jira_project:
base_jql = f"project = {self.quoted_jira_project}"
return f"{base_jql} AND {time_jql}"
return time_jql
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraConnectorCheckpoint,
) -> CheckpointOutput[JiraConnectorCheckpoint]:
jql = self._get_jql_query(start, end)
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.offset or 0
current_offset = starting_offset
for issue in _perform_jql_search(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
start=current_offset,
max_results=_JIRA_FULL_PAGE_SIZE,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
issue_key = issue.key
try:
if document := process_jira_issue(
jira_client=self.jira_client,
issue=issue,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
yield document
yield document_batch
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_jira_url(self.jira_client, issue_key),
),
failure_message=f"Failed to process Jira issue: {str(e)}",
exception=e,
)
current_offset += 1
# Update checkpoint
checkpoint = JiraConnectorCheckpoint(
offset=current_offset,
# if we didn't retrieve a full batch, we're done
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
)
return checkpoint
def retrieve_all_slim_documents(
self,
@ -260,12 +267,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = self._get_jql_query()
jql = self._get_jql_query(start or 0, end or float("inf"))
slim_doc_batch = []
for issue in _paginate_jql_search(
for issue in _perform_jql_search(
jira_client=self.jira_client,
jql=jql,
start=0,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
@ -334,6 +342,16 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraConnectorCheckpoint:
    """Parse and validate a serialized checkpoint JSON string.

    Raises a pydantic ValidationError if the JSON does not match the
    JiraConnectorCheckpoint schema.
    """
    return JiraConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def build_dummy_checkpoint(self) -> JiraConnectorCheckpoint:
    """Return the initial checkpoint used before any issues have been fetched.

    has_more=True so the first load_from_checkpoint call is performed;
    offset is left at its default (None, i.e. start from the beginning).
    """
    return JiraConnectorCheckpoint(
        has_more=True,
    )
if __name__ == "__main__":
import os
@ -350,5 +368,7 @@ if __name__ == "__main__":
"jira_api_token": os.environ["JIRA_API_TOKEN"],
}
)
document_batches = connector.load_from_state()
document_batches = connector.load_from_checkpoint(
0, float("inf"), JiraConnectorCheckpoint(has_more=True)
)
print(next(document_batches))

View File

@ -10,13 +10,15 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from pydantic import BaseModel
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import SLACK_NUM_THREADS
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@ -56,8 +58,8 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
class SlackCheckpoint(ConnectorCheckpoint):
channel_ids: list[str] | None
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
@ -434,6 +436,16 @@ def _get_all_doc_ids(
yield slim_doc_batch
class ProcessedSlackMessage(BaseModel):
    """Outcome of processing a single Slack message: a document, a failure, or neither."""

    # Document built from the message/thread; None when the message produced
    # no document (e.g. filtered out) or when processing failed.
    doc: Document | None
    # if the message is part of a thread, this is the thread_ts
    # otherwise, this is the message_ts. Either way, will be a unique identifier.
    # In the future, if the message becomes a thread, then the thread_ts
    # will be set to the message_ts.
    thread_or_message_ts: str
    # Set when an exception was raised while processing the message;
    # doc is None in that case.
    failure: ConnectorFailure | None
def _process_message(
message: MessageType,
client: WebClient,
@ -442,8 +454,9 @@ def _process_message(
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
) -> ProcessedSlackMessage:
thread_ts = message.get("thread_ts")
thread_or_message_ts = thread_ts or message["ts"]
try:
# causes random failures for testing checkpointing / continue on failure
# import random
@ -459,16 +472,18 @@ def _process_message(
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
return ProcessedSlackMessage(
doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
return ProcessedSlackMessage(
doc=None,
thread_or_message_ts=thread_or_message_ts,
failure=ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
channel_id=channel["id"], thread_ts=thread_or_message_ts
),
document_link=get_message_link(message, client, channel["id"]),
),
@ -478,8 +493,8 @@ def _process_message(
)
class SlackConnector(SlimConnector, CheckpointConnector):
MAX_WORKERS = 2
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
FAST_TIMEOUT = 1
def __init__(
self,
@ -488,12 +503,14 @@ class SlackConnector(SlimConnector, CheckpointConnector):
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
self.num_threads = num_threads
self.client: WebClient | None = None
self.fast_client: WebClient | None = None
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
@ -501,6 +518,10 @@ class SlackConnector(SlimConnector, CheckpointConnector):
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
self.fast_client = WebClient(
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
@ -524,8 +545,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: SlackCheckpoint,
) -> CheckpointOutput[SlackCheckpoint]:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
@ -541,49 +562,36 @@ class SlackConnector(SlimConnector, CheckpointConnector):
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint))
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
if checkpoint.channel_ids is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
if len(filtered_channels) == 0:
checkpoint.has_more = False
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
checkpoint.current_channel = filtered_channels[0]
checkpoint.has_more = True
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
final_channel_ids = checkpoint.channel_ids
channel = checkpoint.current_channel
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
raise ValueError("current_channel key not set in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
seen_thread_ts = set(checkpoint.seen_thread_ts)
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
@ -594,8 +602,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=SlackConnector.MAX_WORKERS) as executor:
futures: list[Future] = []
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures: list[Future[ProcessedSlackMessage]] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
@ -613,46 +621,46 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
for future in as_completed(futures):
doc, thread_ts, failures = future.result()
processed_slack_message = future.result()
doc = processed_slack_message.doc
thread_or_message_ts = processed_slack_message.thread_or_message_ts
failure = processed_slack_message.failure
if doc:
# handle race conditions here since this is single
# threaded. Multi-threaded _process_message reads from this
# but since this is single threaded, we won't run into simul
# writes. At worst, we can duplicate a thread, which will be
# deduped later on.
if thread_ts not in seen_thread_ts:
if thread_or_message_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failures:
for failure in failures:
yield failure
assert (
thread_or_message_ts
), "found non-None doc with None thread_or_message_ts"
seen_thread_ts.add(thread_or_message_ts)
elif failure:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
checkpoint.current_channel = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
if channel_id not in checkpoint.channel_completion_map
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
checkpoint.current_channel = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint.current_channel = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
checkpoint.has_more = checkpoint.current_channel is not None
return checkpoint
except Exception as e:
@ -676,12 +684,12 @@ class SlackConnector(SlimConnector, CheckpointConnector):
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
"""
if self.client is None:
if self.fast_client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# 1) Validate connection to workspace
auth_response = self.client.auth_test()
auth_response = self.fast_client.auth_test()
if not auth_response.get("ok", False):
error_msg = auth_response.get(
"error", "Unknown error from Slack auth_test"
@ -689,7 +697,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
# 2) Minimal test to confirm listing channels works
test_resp = self.client.conversations_list(
test_resp = self.fast_client.conversations_list(
limit=1, types=["public_channel"]
)
if not test_resp.get("ok", False):
@ -707,29 +715,41 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,
get_public=True,
get_private=True,
)
# For quick lookups by name or ID, build a map:
accessible_channel_names = {ch["name"] for ch in accessible_channels}
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
# NOTE: removed this for now since it may be too slow for large workspaces which may
# have some automations which create a lot of channels (100k+)
for user_channel in self.channels:
if (
user_channel not in accessible_channel_names
and user_channel not in accessible_channel_ids
):
raise ConnectorValidationError(
f"Channel '{user_channel}' not found or inaccessible in this workspace."
)
# if self.channels and not self.channel_regex_enabled:
# accessible_channels = get_channels(
# client=self.fast_client,
# exclude_archived=True,
# get_public=True,
# get_private=True,
# )
# # For quick lookups by name or ID, build a map:
# accessible_channel_names = {ch["name"] for ch in accessible_channels}
# accessible_channel_ids = {ch["id"] for ch in accessible_channels}
# for user_channel in self.channels:
# if (
# user_channel not in accessible_channel_names
# and user_channel not in accessible_channel_ids
# ):
# raise ConnectorValidationError(
# f"Channel '{user_channel}' not found or inaccessible in this workspace."
# )
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
if slack_error == "ratelimited":
# Handle rate limiting specifically
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.warning(
f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
"Proceeding with validation, but be aware that connector operations might be throttled."
)
# Continue validation without failing - the connector is likely valid but just rate limited
return
elif slack_error == "missing_scope":
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list/access channels. "
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
@ -752,6 +772,20 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Unexpected error during Slack settings validation: {e}"
)
@override
def build_dummy_checkpoint(self) -> SlackCheckpoint:
    """Return the initial checkpoint for a fresh Slack indexing run.

    channel_ids=None signals that the channel list has not been fetched
    yet; the first load_from_checkpoint call populates it. has_more=True
    ensures that first call actually happens.
    """
    return SlackCheckpoint(
        channel_ids=None,
        channel_completion_map={},
        current_channel=None,
        seen_thread_ts=[],
        has_more=True,
    )
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> SlackCheckpoint:
    """Parse and validate a serialized checkpoint JSON string.

    Raises a pydantic ValidationError if the JSON does not match the
    SlackCheckpoint schema.
    """
    return SlackCheckpoint.model_validate_json(checkpoint_json)
if __name__ == "__main__":
import os
@ -766,9 +800,11 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector.build_dummy_checkpoint()
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
gen = connector.load_from_checkpoint(
one_day_ago, current, cast(SlackCheckpoint, checkpoint)
)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):

View File

@ -1,2 +1,4 @@
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
DEFAULT_PERSONA_SLACK_CHANNEL_NAME = "DEFAULT_SLACK_CHANNEL"
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:"

View File

@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from shared_configs.enums import EmbeddingProvider
@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
) -> FullLLMProvider:
) -> LLMProviderView:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
@ -98,7 +98,7 @@ def upsert_llm_provider(
group_ids=llm_provider.groups,
db_session=db_session,
)
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
db_session.commit()
@ -132,6 +132,16 @@ def fetch_existing_llm_providers(
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_provider(
    provider_name: str, db_session: Session
) -> LLMProviderModel | None:
    """Look up an LLM provider row by its unique name.

    Returns the ORM model, or None when no provider with that name exists.
    """
    lookup_stmt = select(LLMProviderModel).where(
        LLMProviderModel.name == provider_name
    )
    return db_session.scalar(lookup_stmt)
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
@ -177,7 +187,7 @@ def fetch_embedding_provider(
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
def fetch_llm_provider_view(
db_session: Session, provider_name: str
) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def remove_embedding_provider(

View File

@ -1,7 +1,9 @@
from abc import ABC
from abc import abstractmethod
from typing import cast
from typing import IO
import puremagic
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
@ -12,6 +14,7 @@ from onyx.db.pg_file_store import delete_pgfilestore_by_file_name
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.utils.file import FileWithMimeType
class FileStore(ABC):
@ -140,6 +143,18 @@ class PostgresBackedFileStore(FileStore):
self.db_session.rollback()
raise
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
    """Read a stored file and best-effort sniff its MIME type.

    Returns the file contents paired with the detected MIME type, or None
    if the file cannot be read — callers treat None as "file not found"
    rather than an error.
    """
    # Fallback used when puremagic cannot identify the content.
    mime_type: str = "application/octet-stream"
    try:
        file_io = self.read_file(filename, mode="b")
        file_content = file_io.read()
        # Detect the MIME type from the file's magic bytes; take the
        # first match returned by puremagic.
        matches = puremagic.magic_string(file_content)
        if matches:
            mime_type = cast(str, matches[0].mime_type)
        return FileWithMimeType(data=file_content, mime_type=mime_type)
    except Exception:
        # Best-effort: any failure (missing file, read error, sniffing
        # error) is reported as "no file" rather than raised.
        return None
def get_default_file_store(db_session: Session) -> FileStore:
# The only supported file store now is the Postgres File Store

View File

@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@ -62,7 +62,7 @@ def get_llms_for_persona(
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
llm_provider = fetch_llm_provider_view(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")
@ -106,7 +106,7 @@ def get_default_llm_with_vision(
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
"""Helper to create an LLM if the provider supports image input."""
return get_llm(
provider=provider.provider,
@ -148,7 +148,7 @@ def get_default_llm_with_vision(
provider.default_vision_model, provider.provider
):
return create_vision_llm(
FullLLMProvider.from_model(provider), provider.default_vision_model
LLMProviderView.from_model(provider), provider.default_vision_model
)
return None

View File

@ -41,6 +41,7 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackBot
from onyx.db.search_settings import get_current_search_settings
from onyx.db.slack_bot import fetch_slack_bot
from onyx.db.slack_bot import fetch_slack_bots
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
@ -519,6 +520,25 @@ class SlackbotHandler:
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
# skip cases where the bot is disabled in the web UI
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
with get_session_with_current_tenant() as db_session:
slack_bot = fetch_slack_bot(
db_session=db_session, slack_bot_id=client.slack_bot_id
)
if not slack_bot:
logger.error(
f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request."
)
return False
if not slack_bot.enabled:
logger.info(
f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request."
)
return False
if req.type == "events_api":
# Verify channel is valid
event = cast(dict[str, Any], req.payload.get("event", {}))

View File

@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider
@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
@ -49,11 +49,27 @@ def fetch_llm_options(
def test_llm_configuration(
test_llm_request: TestLLMRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
"""Test regular llm and fast llm settings"""
# the api key is sanitized if we are testing a provider already in the system
test_api_key = test_llm_request.api_key
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
test_llm_request.name, db_session
)
if existing_provider:
test_api_key = existing_provider.api_key
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
api_key=test_llm_request.api_key,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
@ -69,7 +85,7 @@ def test_llm_configuration(
fast_llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name,
api_key=test_llm_request.api_key,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
@ -119,11 +135,17 @@ def test_default_provider(
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
return [
FullLLMProvider.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
) -> list[LLMProviderView]:
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(db_session):
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
if full_llm_provider.api_key:
full_llm_provider.api_key = (
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
)
llm_provider_list.append(full_llm_provider)
return llm_provider_list
@admin_router.put("/provider")
@ -135,11 +157,11 @@ def put_llm_provider(
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
) -> LLMProviderView:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
@ -161,6 +183,11 @@ def put_llm_provider(
llm_provider.fast_default_model_name
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider.api_key_changed:
llm_provider.api_key = existing_provider.api_key
try:
return upsert_llm_provider(
llm_provider=llm_provider,
@ -234,7 +261,7 @@ def get_vision_capable_providers(
# Only include providers with at least one vision-capable model
if vision_models:
provider_dict = FullLLMProvider.from_model(provider).model_dump()
provider_dict = LLMProviderView.from_model(provider).model_dump()
provider_dict["vision_models"] = vision_models
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"

View File

@ -12,6 +12,7 @@ if TYPE_CHECKING:
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
provider: str
api_key: str | None = None
api_base: str | None = None
@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
model_names: list[str] | None = None
api_key_changed: bool = False
class FullLLMProvider(LLMProvider):
class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_names: list[str]
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider):
)
class VisionProviderResponse(FullLLMProvider):
class VisionProviderResponse(LLMProviderView):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]

View File

@ -32,10 +32,14 @@ from onyx.server.manage.models import SlackChannelConfig
from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/manage")
@ -376,7 +380,7 @@ def get_all_channels_from_slack_api(
status_code=404, detail="Bot token not found for the given bot ID"
)
client = WebClient(token=tokens["bot_token"])
client = WebClient(token=tokens["bot_token"], timeout=1)
all_channels = []
next_cursor = None
current_page = 0
@ -431,6 +435,7 @@ def get_all_channels_from_slack_api(
except SlackApiError as e:
# Handle rate limiting or other API errors
logger.exception("Error fetching channels from Slack API")
raise HTTPException(
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",

View File

@ -351,9 +351,11 @@ def remove_invited_user(
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
if MULTI_TENANT:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)
try:

View File

@ -0,0 +1,89 @@
import io
from PIL import Image
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine import get_session_with_shared_schema
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.utils.file import FileWithMimeType
from onyx.utils.file import OnyxStaticFileManager
from onyx.utils.variable_functionality import (
fetch_ee_implementation_or_noop,
)
class OnyxRuntime:
    """Used by the application to get the final runtime value of a setting.

    Rationale: Settings and overrides may be persisted in multiple places, including the
    DB, Redis, env vars, and default constants, etc. The logic to present a final
    setting to the application should be centralized and in one place.

    Example: To get the logo for the application, one must check the DB for an override,
    use the override if present, fall back to the filesystem if not present, and worry
    about enterprise or not enterprise.
    """

    @staticmethod
    def _get_with_static_fallback(
        db_filename: str | None, static_filename: str
    ) -> FileWithMimeType:
        """Resolve a resource: DB-stored override first, bundled static file second.

        Raises RuntimeError when neither source yields the resource.
        """
        onyx_file: FileWithMimeType | None = None

        # Prefer the DB-stored override (e.g. an enterprise-uploaded logo).
        if db_filename:
            with get_session_with_shared_schema() as db_session:
                file_store = PostgresBackedFileStore(db_session)
                onyx_file = file_store.get_file_with_mime_type(db_filename)

        # Fall back to the static asset shipped with the application.
        if not onyx_file:
            onyx_file = OnyxStaticFileManager.get_static(static_filename)

        if not onyx_file:
            raise RuntimeError(
                f"Resource not found: db={db_filename} static={static_filename}"
            )

        return onyx_file

    @staticmethod
    def get_logo() -> FileWithMimeType:
        """Return the application logo (DB override or bundled default)."""
        STATIC_FILENAME = "static/images/logo.png"

        # NOTE(review): fetch_ee_implementation_or_noop returns a callable; as
        # written here db_filename is bound to that callable rather than its
        # result, so the truthiness check in _get_with_static_fallback would
        # always pass. Verify whether a trailing () invocation is missing.
        db_filename: str | None = fetch_ee_implementation_or_noop(
            "onyx.server.enterprise_settings.store", "get_logo_filename", None
        )

        return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)

    @staticmethod
    def get_emailable_logo() -> FileWithMimeType:
        """Return the logo constrained to an email-friendly size and PNG format."""
        onyx_file = OnyxRuntime.get_logo()

        # check dimensions and resize downwards if necessary or if not PNG
        image = Image.open(io.BytesIO(onyx_file.data))
        if (
            image.size[0] > ONYX_EMAILABLE_LOGO_MAX_DIM
            or image.size[1] > ONYX_EMAILABLE_LOGO_MAX_DIM
            or image.format != "PNG"
        ):
            # thumbnail() only shrinks; a small non-PNG image passes through it
            # unchanged in size and is merely re-encoded as PNG below.
            image.thumbnail(
                (ONYX_EMAILABLE_LOGO_MAX_DIM, ONYX_EMAILABLE_LOGO_MAX_DIM),
                Image.LANCZOS,
            )  # maintains aspect ratio

            output_buffer = io.BytesIO()
            image.save(output_buffer, format="PNG")
            onyx_file = FileWithMimeType(
                data=output_buffer.getvalue(), mime_type="image/png"
            )

        return onyx_file

    @staticmethod
    def get_logotype() -> FileWithMimeType:
        """Return the logotype image (DB override or bundled default)."""
        STATIC_FILENAME = "static/images/logotype.png"

        # NOTE(review): same concern as in get_logo — the fetched callable
        # does not appear to be invoked; confirm intended.
        db_filename: str | None = fetch_ee_implementation_or_noop(
            "onyx.server.enterprise_settings.store", "get_logotype_filename", None
        )

        return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)

View File

@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
groups=[],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
api_key_changed=True,
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session
@ -323,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request()
gpu_available = gpu_status_request(indexing=True)
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)

View File

@ -21,7 +21,6 @@ def build_tool_message(
)
# TODO: does this NEED to be BaseModel__v1?
class ToolCallSummary(BaseModel):
tool_call_request: AIMessage
tool_call_result: ToolMessage

View File

@ -0,0 +1,36 @@
from typing import cast
import puremagic
from pydantic import BaseModel
from onyx.utils.logger import setup_logger
logger = setup_logger()
class FileWithMimeType(BaseModel):
    """A file's raw bytes paired with its (sniffed) MIME type."""

    data: bytes
    mime_type: str
class OnyxStaticFileManager:
    """Retrieve static resources with this class. Currently, these should all be located
    in the static directory ... e.g. static/images/logo.png"""

    @staticmethod
    def get_static(filename: str) -> FileWithMimeType | None:
        """Read *filename* from disk and sniff its MIME type.

        Returns None (after logging) on any read failure; falls back to
        application/octet-stream when the type cannot be detected.
        """
        try:
            mime_type: str = "application/octet-stream"
            with open(filename, "rb") as f:
                file_content = f.read()

            matches = puremagic.magic_string(file_content)
            if matches:
                mime_type = cast(str, matches[0].mime_type)
        except OSError as e:
            # FileNotFoundError and PermissionError are OSError subclasses,
            # so a single clause covers all filesystem failures.
            logger.error(f"Failed to read file {filename}: {e}")
            return None
        except Exception as e:
            logger.error(f"Unexpected exception reading file {filename}: {e}")
            return None

        return FileWithMimeType(data=file_content, mime_type=mime_type)

View File

@ -1,3 +1,5 @@
from functools import lru_cache
import requests
from retry import retry
@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool = True) -> bool:
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
if indexing:
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
else:
@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool:
except requests.RequestException as e:
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
raise # Re-raise exception to trigger a retry
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool) -> bool:
return _get_gpu_status_from_model_server(indexing)
@lru_cache(maxsize=1)
def fast_gpu_status_request(indexing: bool) -> bool:
"""For use in sync flows, where we don't want to retry / we want to cache this."""
return gpu_status_request(indexing=indexing)

View File

@ -0,0 +1,13 @@
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar

R = TypeVar("R")


def lazy_eval(func: Callable[[], R]) -> Callable[[], R]:
    """Wrap a zero-argument callable so it is evaluated at most once.

    The wrapped callable runs ``func`` on its first invocation and returns the
    memoized result on every later call.
    """

    def lazy_func() -> R:
        return func()

    # lru_cache(maxsize=1) on a nullary function is exactly run-once memoization.
    return lru_cache(maxsize=1)(lazy_func)

View File

@ -1,18 +1,148 @@
import collections.abc
import contextvars
import copy
import threading
import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import Generic
from typing import overload
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from onyx.utils.logger import setup_logger
logger = setup_logger()
R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
_T = TypeVar("_T") # Default type
class ThreadSafeDict(MutableMapping[KT, VT]):
"""
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
Implements the MutableMapping interface to provide a complete dictionary-like interface.
Example usage:
# Create a thread-safe dictionary
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
# Basic operations (atomic)
safe_dict["key"] = 1
value = safe_dict["key"]
del safe_dict["key"]
# Bulk operations (atomic)
safe_dict.update({"key1": 1, "key2": 2})
"""
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
self._dict: dict[KT, VT] = input_dict or {}
self.lock = threading.Lock()
def __getitem__(self, key: KT) -> VT:
with self.lock:
return self._dict[key]
def __setitem__(self, key: KT, value: VT) -> None:
with self.lock:
self._dict[key] = value
def __delitem__(self, key: KT) -> None:
with self.lock:
del self._dict[key]
def __iter__(self) -> Iterator[KT]:
# Return a snapshot of keys to avoid potential modification during iteration
with self.lock:
return iter(list(self._dict.keys()))
def __len__(self) -> int:
with self.lock:
return len(self._dict)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls.validate, handler(dict[KT, VT])
)
@classmethod
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
if isinstance(v, dict):
return ThreadSafeDict(v)
return v
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
return ThreadSafeDict(copy.deepcopy(self._dict))
def clear(self) -> None:
"""Remove all items from the dictionary atomically."""
with self.lock:
self._dict.clear()
def copy(self) -> dict[KT, VT]:
"""Return a shallow copy of the dictionary atomically."""
with self.lock:
return self._dict.copy()
@overload
def get(self, key: KT) -> VT | None:
...
@overload
def get(self, key: KT, default: VT | _T) -> VT | _T:
...
def get(self, key: KT, default: Any = None) -> Any:
"""Get a value with a default, atomically."""
with self.lock:
return self._dict.get(key, default)
def pop(self, key: KT, default: Any = None) -> Any:
"""Remove and return a value with optional default, atomically."""
with self.lock:
if default is None:
return self._dict.pop(key)
return self._dict.pop(key, default)
def setdefault(self, key: KT, default: VT) -> VT:
"""Set a default value if key is missing, atomically."""
with self.lock:
return self._dict.setdefault(key, default)
def update(self, *args: Any, **kwargs: VT) -> None:
"""Update the dictionary atomically from another mapping or from kwargs."""
with self.lock:
self._dict.update(*args, **kwargs)
def items(self) -> collections.abc.ItemsView[KT, VT]:
"""Return a view of (key, value) pairs atomically."""
with self.lock:
return collections.abc.ItemsView(self)
def keys(self) -> collections.abc.KeysView[KT]:
"""Return a view of keys atomically."""
with self.lock:
return collections.abc.KeysView(self)
def values(self) -> collections.abc.ValuesView[VT]:
"""Return a view of values atomically."""
with self.lock:
return collections.abc.ValuesView(self)
def run_functions_tuples_in_parallel(
@ -190,3 +320,27 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
raise task.exception
return task.result
def _next_or_none(ind: int, g: "Iterator[R]") -> "tuple[int, R | None]":
    """Advance generator *g* once; None signals exhaustion to parallel_yield."""
    return ind, next(g, None)


def parallel_yield(gens: "list[Iterator[R]]", max_workers: int = 10) -> "Iterator[R]":
    """Interleave values from *gens*, advancing each on a worker thread.

    Items are yielded in completion order, not source order. A generator that
    produces None is treated as exhausted, so None values are silently
    dropped. Each generator has at most one outstanding worker at a time, so
    the (non-thread-safe) generators are never advanced concurrently.

    Fix over the previous version: the dict values previously stored a
    monotonically increasing counter (next_ind) that was never read; we now
    consistently store the generator's index.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each in-flight future to the index of the generator it advances.
        future_to_index: dict[Future["tuple[int, R | None]"], int] = {
            executor.submit(_next_or_none, ind, gen): ind
            for ind, gen in enumerate(gens)
        }

        while future_to_index:
            done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
            for future in done:
                ind, result = future.result()
                del future_to_index[future]
                if result is not None:
                    yield result
                    # Re-arm this generator; exhausted ones are simply dropped.
                    future_to_index[
                        executor.submit(_next_or_none, ind, gens[ind])
                    ] = ind

View File

@ -52,6 +52,7 @@ openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5
psycopg2-binary==2.9.9
puremagic==1.28
pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.8.2

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

View File

@ -1,5 +1,6 @@
import os
import time
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
@ -7,15 +8,16 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.utils import AttachmentProcessingResult
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.connectors.models import Document
@pytest.fixture
def confluence_connector() -> ConfluenceConnector:
def confluence_connector(space: str) -> ConfluenceConnector:
connector = ConfluenceConnector(
wiki_base=os.environ["CONFLUENCE_TEST_SPACE_URL"],
space=os.environ["CONFLUENCE_TEST_SPACE"],
space=space,
is_cloud=os.environ.get("CONFLUENCE_IS_CLOUD", "true").lower() == "true",
page_id=os.environ.get("CONFLUENCE_TEST_PAGE_ID", ""),
)
@ -32,14 +34,15 @@ def confluence_connector() -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
confluence_connector.set_allow_images(False)
doc_batch_generator = confluence_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
@ -50,15 +53,14 @@ def test_confluence_connector_basic(
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None
for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
elif ".txt" in doc.semantic_identifier:
txt_doc = doc
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc
else:
pass
assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
@ -79,7 +81,7 @@ def test_confluence_connector_basic(
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
assert len(page_doc.sections) == 1
assert len(page_doc.sections) == 2 # page text + attachment text
page_section = page_doc.sections[0]
assert page_section.text == "test123 " + page_within_a_page_text
@ -88,13 +90,65 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)
assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"
assert txt_doc.primary_owners
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
assert (
txt_doc.sections[0].link
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
text_attachment_section = page_doc.sections[1]
assert text_attachment_section.text == "small"
assert text_attachment_section.link
assert text_attachment_section.link.endswith("small-file.txt")
@pytest.mark.parametrize("space", ["MI"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_confluence_connector_skip_images(
    mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
    """With image processing disabled, only text content is indexed.

    Against the "MI" test space, expects 8 documents and 8 sections in total
    (i.e. no extra image-derived sections).
    """
    confluence_connector.set_allow_images(False)

    doc_batch_generator = confluence_connector.poll_source(0, time.time())
    doc_batch = next(doc_batch_generator)
    # The whole space should fit into a single batch.
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 8
    assert sum(len(doc.sections) for doc in doc_batch) == 8
def mock_process_image_attachment(
    *args: Any, **kwargs: Any
) -> AttachmentProcessingResult:
    """Stub for _process_image_attachment that returns a canned success result.

    We need this mock to bypass DB access happening in the connector. Which shouldn't
    be done as a rule to begin with, but life is not perfect. Fix it later
    """
    return AttachmentProcessingResult(
        text="Hi_text",
        file_name="Hi_filename",
        error=None,
    )
@pytest.mark.parametrize("space", ["MI"])
@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
@patch(
    "onyx.connectors.confluence.utils._process_image_attachment",
    side_effect=mock_process_image_attachment,
)
def test_confluence_connector_allow_images(
    # NOTE(review): @patch decorators inject mocks bottom-up, so the first
    # parameter receives the _process_image_attachment patch and the second
    # receives get_unstructured_api_key — these names look swapped; confirm.
    mock_get_api_key: MagicMock,
    mock_process_image_attachment: MagicMock,
    confluence_connector: ConfluenceConnector,
) -> None:
    """With image processing enabled (and mocked), image sections are added.

    Same 8 documents as the skip-images test, but 12 sections in total —
    the extra 4 coming from (mocked) image attachment processing.
    """
    confluence_connector.set_allow_images(True)

    doc_batch_generator = confluence_connector.poll_source(0, time.time())
    doc_batch = next(doc_batch_generator)
    # The whole space should fit into a single batch.
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 8
    assert sum(len(doc.sections) for doc in doc_batch) == 12

View File

@ -1,7 +1,10 @@
import time
from collections.abc import Sequence
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
ALL_FILES = list(range(0, 60))
SHARED_DRIVE_FILES = list(range(20, 25))
@ -21,6 +24,7 @@ FOLDER_2_FILE_IDS = list(range(45, 50))
FOLDER_2_1_FILE_IDS = list(range(50, 55))
FOLDER_2_2_FILE_IDS = list(range(55, 60))
SECTIONS_FILE_IDS = [61]
FOLDER_3_FILE_IDS = list(range(62, 65))
PUBLIC_FOLDER_RANGE = FOLDER_1_2_FILE_IDS
PUBLIC_FILE_IDS = list(range(55, 57))
@ -54,6 +58,8 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
ADMIN_EMAIL = "admin@onyx-test.com"
TEST_USER_1_EMAIL = "test_user_1@onyx-test.com"
TEST_USER_2_EMAIL = "test_user_2@onyx-test.com"
@ -133,17 +139,19 @@ def filter_invalid_prefixes(names: set[str]) -> set[str]:
return {name for name in names if name.startswith(_VALID_PREFIX)}
def print_discrepencies(
def print_discrepancies(
expected: set[str],
retrieved: set[str],
) -> None:
if expected != retrieved:
print(expected)
print(retrieved)
expected_list = sorted(expected)
retrieved_list = sorted(retrieved)
print(expected_list)
print(retrieved_list)
print("Extra:")
print(retrieved - expected)
print(sorted(retrieved - expected))
print("Missing:")
print(expected - retrieved)
print(sorted(expected - retrieved))
def _get_expected_file_content(file_id: int) -> str:
@ -164,6 +172,8 @@ def assert_retrieved_docs_match_expected(
_get_expected_file_content(file_id) for file_id in expected_file_ids
}
retrieved_docs.sort(key=lambda x: x.semantic_identifier)
for doc in retrieved_docs:
print(f"doc.semantic_identifier: {doc.semantic_identifier}")
@ -190,15 +200,23 @@ def assert_retrieved_docs_match_expected(
)
# Check file names
print_discrepencies(
print_discrepancies(
expected=expected_file_names,
retrieved=valid_retrieved_file_names,
)
assert expected_file_names == valid_retrieved_file_names
# Check file texts
print_discrepencies(
print_discrepancies(
expected=expected_file_texts,
retrieved=valid_retrieved_texts,
)
assert expected_file_texts == valid_retrieved_texts
def load_all_docs(connector: GoogleDriveConnector) -> list[Document]:
return load_all_docs_from_checkpoint_connector(
connector,
0,
time.time(),
)

View File

@ -1,10 +1,8 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
@ -47,9 +46,7 @@ def test_include_all(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should get everything in shared and admin's My Drive with oauth
expected_file_ids = (
@ -89,9 +86,7 @@ def test_include_shared_drives_only(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get shared drives
expected_file_ids = (
@ -129,9 +124,7 @@ def test_include_my_drives_only(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get primary_admins My Drive because we are impersonating them
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
@ -160,9 +153,7 @@ def test_drive_one_only(
my_drive_emails=None,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
SHARED_DRIVE_1_FILE_IDS
@ -196,9 +187,7 @@ def test_folder_and_shared_drive(
my_drive_emails=None,
shared_drive_urls=",".join([str(url) for url in drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
SHARED_DRIVE_1_FILE_IDS
@ -243,9 +232,7 @@ def test_folders_only(
my_drive_emails=None,
shared_drive_urls=",".join([str(url) for url in shared_drive_urls]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
FOLDER_1_1_FILE_IDS
@ -281,9 +268,7 @@ def test_personal_folders_only(
my_drive_emails=None,
shared_drive_urls=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(

View File

@ -1,11 +1,10 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_URL
@ -37,9 +36,7 @@ def test_google_drive_sections(
my_drive_emails=None,
)
for connector in [oauth_connector, service_acct_connector]:
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Verify we got the 1 doc with sections
assert len(retrieved_docs) == 1

View File

@ -1,10 +1,8 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
@ -23,6 +21,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
@ -52,9 +51,7 @@ def test_include_all(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should get everything
expected_file_ids = (
@ -97,9 +94,7 @@ def test_include_shared_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get shared drives
expected_file_ids = (
@ -137,9 +132,7 @@ def test_include_my_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should only get everyone's My Drives
expected_file_ids = (
@ -174,9 +167,7 @@ def test_drive_one_only(
shared_drive_urls=",".join([str(url) for url in urls]),
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# We ignore shared_drive_urls if include_shared_drives is False
expected_file_ids = (
@ -211,9 +202,7 @@ def test_folder_and_shared_drive(
shared_folder_urls=",".join([str(url) for url in folder_urls]),
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# Should get everything except for the top level files in drive 2
expected_file_ids = (
@ -259,9 +248,7 @@ def test_folders_only(
shared_folder_urls=",".join([str(url) for url in folder_urls]),
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
FOLDER_1_1_FILE_IDS
@ -298,9 +285,7 @@ def test_specific_emails(
shared_drive_urls=None,
my_drive_emails=",".join([str(email) for email in my_drive_emails]),
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
assert_retrieved_docs_match_expected(
@ -330,9 +315,7 @@ def get_specific_folders_in_my_drive(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(

View File

@ -22,7 +22,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_I
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import print_discrepencies
from tests.daily.connectors.google_drive.consts_and_utils import print_discrepancies
from tests.daily.connectors.google_drive.consts_and_utils import PUBLIC_RANGE
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
@ -83,7 +83,7 @@ def assert_correct_access_for_user(
expected_file_names = {file_name_template.format(i) for i in all_accessible_ids}
filtered_retrieved_file_names = filter_invalid_prefixes(retrieved_file_names)
print_discrepencies(expected_file_names, filtered_retrieved_file_names)
print_discrepancies(expected_file_names, filtered_retrieved_file_names)
assert expected_file_names == filtered_retrieved_file_names
@ -175,7 +175,7 @@ def test_all_permissions(
# Should get everything
filtered_retrieved_file_names = filter_invalid_prefixes(found_file_names)
print_discrepencies(expected_file_names, filtered_retrieved_file_names)
print_discrepancies(expected_file_names, filtered_retrieved_file_names)
assert expected_file_names == filtered_retrieved_file_names
group_map = get_group_map(google_drive_connector)

View File

@ -1,10 +1,8 @@
import time
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_retrieved_docs_match_expected,
@ -14,6 +12,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
from tests.daily.connectors.google_drive.consts_and_utils import load_all_docs
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import TEST_USER_1_FILE_IDS
@ -37,9 +36,7 @@ def test_all(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# These are the files from my drive
@ -77,9 +74,7 @@ def test_shared_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# These are the files from shared drives
@ -112,9 +107,7 @@ def test_shared_with_me_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# These are the files shared with me from admin
@ -145,9 +138,7 @@ def test_my_drive_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
# These are the files from my drive
expected_file_ids = TEST_USER_1_FILE_IDS
@ -175,9 +166,7 @@ def test_shared_my_drive_folder(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = (
# this is a folder from admin's drive that is shared with me
@ -207,9 +196,7 @@ def test_shared_drive_folder(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs: list[Document] = []
for doc_batch in connector.poll_source(0, time.time()):
retrieved_docs.extend(doc_batch)
retrieved_docs = load_all_docs(connector)
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
assert_retrieved_docs_match_expected(

View File

@ -5,6 +5,7 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.connector import JiraConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@pytest.fixture
@ -24,15 +25,13 @@ def jira_connector() -> JiraConnector:
def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
doc_batch_generator = jira_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 1
doc = doc_batch[0]
docs = load_all_docs_from_checkpoint_connector(
connector=jira_connector,
start=0,
end=time.time(),
)
assert len(docs) == 1
doc = docs[0]
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
assert doc.semantic_identifier == "AS-2: test123small"

View File

@ -0,0 +1,70 @@
from typing import cast
from typing import TypeVar
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
_ITERATION_LIMIT = 100_000
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def load_all_docs_from_checkpoint_connector(
    connector: CheckpointConnector[CT],
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> list[Document]:
    """Drain a checkpoint connector over [start, end] and collect every Document.

    Any ConnectorFailure surfaced by the connector is treated as fatal and
    re-raised as a RuntimeError. A hard cap on loop iterations guards against
    connectors whose checkpoints never report completion.
    """
    collected: list[Document] = []
    checkpoint = cast(CT, connector.build_dummy_checkpoint())

    for _ in range(_ITERATION_LIMIT + 1):
        if not checkpoint.has_more:
            return collected
        wrapped_outputs = CheckpointOutputWrapper[CT]()(
            connector.load_from_checkpoint(start, end, checkpoint)
        )
        for doc, failure, next_checkpoint in wrapped_outputs:
            if failure is not None:
                raise RuntimeError(f"Failed to load documents: {failure}")
            if doc is not None:
                collected.append(doc)
            if next_checkpoint is not None:
                checkpoint = next_checkpoint

    raise RuntimeError("Too many iterations. Infinite loop?")
def load_everything_from_checkpoint_connector(
    connector: CheckpointConnector[CT],
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> list[Document | ConnectorFailure]:
    """Like load_all_docs_from_checkpoint_connector but returns both documents and failures"""
    results: list[Document | ConnectorFailure] = []
    checkpoint = connector.build_dummy_checkpoint()
    batches_processed = 0

    while checkpoint.has_more:
        output_stream = CheckpointOutputWrapper[CT]()(
            connector.load_from_checkpoint(start, end, checkpoint)
        )
        for doc, failure, next_checkpoint in output_stream:
            # Failures are recorded alongside documents instead of raising.
            if failure is not None:
                results.append(failure)
            if doc is not None:
                results.append(doc)
            if next_checkpoint is not None:
                checkpoint = next_checkpoint

        # Guard against connectors that never flip has_more off.
        batches_processed += 1
        if batches_processed > _ITERATION_LIMIT:
            raise RuntimeError("Too many iterations. Infinite loop?")

    return results

View File

@ -28,6 +28,14 @@ The idea is that each test can use the manager class to create (.create()) a "te
pytest -s tests/integration/tests/path_to/test_file.py::test_function_name
```
Running some single tests require the `mock_connector_server` container to be running. If the above doesn't work,
navigate to `backend/tests/integration/mock_services` and run
```sh
docker compose -f docker-compose.mock-it-services.yml -p mock-it-services-stack up -d
```
You will have to modify the networks section of the docker-compose file to `<your stack name>_default` if you brought up the standard
onyx services with a name different from the default `onyx-stack`.
## Guidelines for Writing Integration Tests
- As authentication is currently required for all tests, each test should start by creating a user.

View File

@ -3,8 +3,8 @@ from uuid import uuid4
import requests
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
@ -39,6 +39,7 @@ class LLMProviderManager:
groups=groups or [],
display_model_names=None,
model_names=None,
api_key_changed=True,
)
llm_response = requests.put(
@ -90,7 +91,7 @@ class LLMProviderManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[FullLLMProvider]:
) -> list[LLMProviderView]:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=user_performing_action.headers
@ -98,7 +99,7 @@ class LLMProviderManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullLLMProvider(**ug) for ug in response.json()]
return [LLMProviderView(**ug) for ug in response.json()]
@staticmethod
def verify(
@ -111,18 +112,19 @@ class LLMProviderManager:
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
f"User group {llm_provider.id} found but should be deleted"
f"LLM Provider {llm_provider.id} found but should be deleted"
)
fetched_llm_groups = set(fetched_llm_provider.groups)
llm_provider_groups = set(llm_provider.groups)
# NOTE: returned api keys are sanitized and should not match
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.api_key == fetched_llm_provider.api_key
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and llm_provider.is_public == fetched_llm_provider.is_public
):
return
if not verify_deleted:
raise ValueError(f"User group {llm_provider.id} not found")
raise ValueError(f"LLM Provider {llm_provider.id} not found")

View File

@ -7,7 +7,7 @@ import httpx
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import InputType
@ -54,9 +54,9 @@ def test_mock_connector_basic_flow(
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
@ -128,9 +128,9 @@ def test_mock_connector_with_failures(
json=[
{
"documents": [doc1.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [doc2_failure.model_dump(mode="json")],
}
],
@ -208,9 +208,9 @@ def test_mock_connector_failure_recovery(
json=[
{
"documents": [doc1.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [
doc2_failure.model_dump(mode="json"),
ConnectorFailure(
@ -292,9 +292,9 @@ def test_mock_connector_failure_recovery(
doc1.model_dump(mode="json"),
doc2.model_dump(mode="json"),
],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
).model_dump(mode="json"),
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
@ -372,23 +372,23 @@ def test_mock_connector_checkpoint_recovery(
json=[
{
"documents": [doc.model_dump(mode="json") for doc in docs_batch_1],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=True
"checkpoint": MockConnectorCheckpoint(
has_more=True, last_document_id=docs_batch_1[-1].id
).model_dump(mode="json"),
"failures": [],
},
{
"documents": [doc2.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=True
"checkpoint": MockConnectorCheckpoint(
has_more=True, last_document_id=doc2.id
).model_dump(mode="json"),
"failures": [],
},
{
"documents": [],
# should never hit this, unhandled exception happens first
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
"checkpoint": MockConnectorCheckpoint(
has_more=False, last_document_id=doc2.id
).model_dump(mode="json"),
"failures": [],
"unhandled_exception": "Simulated unhandled error",
@ -446,12 +446,16 @@ def test_mock_connector_checkpoint_recovery(
initial_checkpoints = response.json()
# Verify we got the expected checkpoints in order
assert len(initial_checkpoints) > 0
assert (
initial_checkpoints[0]["checkpoint_content"] == {}
) # Initial empty checkpoint
assert initial_checkpoints[1]["checkpoint_content"] == {}
assert initial_checkpoints[2]["checkpoint_content"] == {}
assert len(initial_checkpoints) == 3
assert initial_checkpoints[0] == {
"has_more": True,
"last_document_id": None,
} # Initial empty checkpoint
assert initial_checkpoints[1] == {
"has_more": True,
"last_document_id": docs_batch_1[-1].id,
}
assert initial_checkpoints[2] == {"has_more": True, "last_document_id": doc2.id}
# Reset the mock server for the next run
response = mock_server_client.post("/reset")
@ -463,8 +467,8 @@ def test_mock_connector_checkpoint_recovery(
json=[
{
"documents": [doc3.model_dump(mode="json")],
"checkpoint": ConnectorCheckpoint(
checkpoint_content={}, has_more=False
"checkpoint": MockConnectorCheckpoint(
has_more=False, last_document_id=doc3.id
).model_dump(mode="json"),
"failures": [],
}
@ -515,4 +519,4 @@ def test_mock_connector_checkpoint_recovery(
# Verify the recovery run started from the last successful checkpoint
assert len(recovery_checkpoints) == 1
assert recovery_checkpoints[0]["checkpoint_content"] == {}
assert recovery_checkpoints[0] == {"has_more": True, "last_document_id": doc2.id}

View File

@ -34,6 +34,7 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
json={
"name": str(uuid.uuid4()),
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
@ -49,6 +50,9 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
assert provider_data["model_names"] == _DEFAULT_MODELS
assert provider_data["default_model_name"] == _DEFAULT_MODELS[0]
assert provider_data["display_model_names"] is None
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that returned key is sanitized
def test_update_llm_provider_model_names(reset: None) -> None:
@ -64,10 +68,12 @@ def test_update_llm_provider_model_names(reset: None) -> None:
json={
"name": name,
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": [_DEFAULT_MODELS[0]],
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200
@ -81,6 +87,7 @@ def test_update_llm_provider_model_names(reset: None) -> None:
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
@ -93,6 +100,30 @@ def test_update_llm_provider_model_names(reset: None) -> None:
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["model_names"] == _DEFAULT_MODELS
assert (
provider_data["api_key"] == "sk-0****0000"
) # test that key was NOT updated due to api_key_changed not being set
# Update with api_key_changed properly set
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
headers=admin_user.headers,
json={
"id": created_provider["id"],
"name": name,
"provider": created_provider["provider"],
"api_key": "sk-000000000000000000000000000000000000000000000001",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,
"groups": [],
"api_key_changed": True,
},
)
assert response.status_code == 200
provider_data = _get_provider_by_id(admin_user, created_provider["id"])
assert provider_data is not None
assert provider_data["api_key"] == "sk-0****0001" # test that key was updated
def test_delete_llm_provider(reset: None) -> None:
@ -107,6 +138,7 @@ def test_delete_llm_provider(reset: None) -> None:
json={
"name": "test-provider-delete",
"provider": "openai",
"api_key": "sk-000000000000000000000000000000000000000000000000",
"default_model_name": _DEFAULT_MODELS[0],
"model_names": _DEFAULT_MODELS,
"is_public": True,

View File

@ -50,7 +50,7 @@ def answer_instance(
mocker: MockerFixture,
) -> Answer:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
"onyx.chat.answer.fast_gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
@ -400,7 +400,7 @@ def test_no_slow_reranking(
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
"onyx.chat.answer.fast_gpu_status_request",
return_value=gpu_enabled,
)
rerank_settings = (

View File

@ -39,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
"onyx.chat.answer.fast_gpu_status_request",
return_value=True,
)
question = config["question"]

View File

@ -0,0 +1,436 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from jira import JIRA
from jira import JIRAError
from jira.resources import Issue
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.onyx_jira.connector import JiraConnector
from onyx.connectors.onyx_jira.connector import JiraConnectorCheckpoint
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
PAGE_SIZE = 2
@pytest.fixture
def jira_base_url() -> str:
    """Base URL of the fictional Jira instance used by every test in this module."""
    return "https://jira.example.com"
@pytest.fixture
def project_key() -> str:
    """Jira project key the connector under test is scoped to."""
    return "TEST"
@pytest.fixture
def mock_jira_client() -> MagicMock:
    """Build a MagicMock standing in for a JIRA client.

    The mock is spec'd against JIRA so unexpected attribute access fails fast;
    the API methods the connector exercises each get their own fresh mock so
    return values / side effects can be configured per test.
    """
    client = MagicMock(spec=JIRA)
    for api_method in ("search_issues", "project", "projects"):
        setattr(client, api_method, MagicMock())
    return client
@pytest.fixture
def jira_connector(
    jira_base_url: str, project_key: str, mock_jira_client: MagicMock
) -> Generator[JiraConnector, None, None]:
    """Yield a JiraConnector wired to the mock client, with the page size forced to 2."""
    conn = JiraConnector(
        jira_base_url=jira_base_url,
        project_key=project_key,
        comment_email_blacklist=["blacklist@example.com"],
        labels_to_skip=["secret", "sensitive"],
    )
    # Inject the fake client and make client_info() report the base URL.
    conn._jira_client = mock_jira_client
    conn._jira_client.client_info.return_value = jira_base_url
    # Tiny page size ensures the pagination code paths run with small fixtures.
    with patch("onyx.connectors.onyx_jira.connector._JIRA_FULL_PAGE_SIZE", 2):
        yield conn
@pytest.fixture
def create_mock_issue() -> Callable[..., MagicMock]:
    """Return a factory that builds MagicMock Jira issues with realistic fields."""

    def _build(
        key: str = "TEST-123",
        summary: str = "Test Issue",
        updated: str = "2023-01-01T12:00:00.000+0000",
        description: str = "Test Description",
        labels: list[str] | None = None,
    ) -> MagicMock:
        """Helper to create a mock Issue object"""
        issue = MagicMock(spec=Issue)
        # The fields container must exist before attributes are hung off it.
        issue.fields = MagicMock()
        issue.key = key
        issue.fields.summary = summary
        issue.fields.updated = updated
        issue.fields.description = description
        issue.fields.labels = labels or []

        # People: creator and assignee drive owner extraction in the connector.
        for role, display_name, email in (
            ("creator", "Test Creator", "creator@example.com"),
            ("assignee", "Test Assignee", "assignee@example.com"),
        ):
            person = MagicMock()
            person.displayName = display_name
            person.emailAddress = email
            setattr(issue.fields, role, person)

        # Workflow metadata: priority / status / resolution.
        for attr, value in (
            ("priority", "High"),
            ("status", "In Progress"),
            ("resolution", "Fixed"),
        ):
            named = MagicMock()
            named.name = value
            setattr(issue.fields, attr, named)

        # Raw payload consulted by API-version-dependent code paths.
        issue.raw = {"fields": {"description": description}}
        return issue

    return _build
def test_load_credentials(jira_connector: JiraConnector) -> None:
    """load_credentials should build a client from the credentials and return None."""
    creds = {
        "jira_user_email": "user@example.com",
        "jira_api_token": "token123",
    }
    with patch(
        "onyx.connectors.onyx_jira.connector.build_jira_client"
    ) as build_client_mock:
        build_client_mock.return_value = jira_connector._jira_client

        outcome = jira_connector.load_credentials(creds)

        # The client factory must be invoked exactly once with our credentials.
        build_client_mock.assert_called_once_with(
            credentials=creds, jira_base=jira_connector.jira_base
        )
        assert outcome is None
        assert jira_connector._jira_client == build_client_mock.return_value
def test_get_jql_query_with_project(jira_connector: JiraConnector) -> None:
    """A project-scoped connector's JQL combines the project clause with time bounds."""
    window_start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
    window_end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()

    jql = jira_connector._get_jql_query(window_start, window_end)

    # Both the project clause and the time window must appear, ANDed together.
    assert f'project = "{jira_connector.jira_project}"' in jql
    assert "updated >= '2023-01-01 00:00'" in jql
    assert "updated <= '2023-01-02 00:00'" in jql
    assert " AND " in jql
def test_get_jql_query_without_project(jira_base_url: str) -> None:
    """Without a project key the JQL must contain only the time filter."""
    # No project_key -> the connector is unscoped.
    unscoped_connector = JiraConnector(jira_base_url=jira_base_url)

    window_start = datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp()
    window_end = datetime(2023, 1, 2, tzinfo=timezone.utc).timestamp()

    jql = unscoped_connector._get_jql_query(window_start, window_end)

    assert "project =" not in jql
    assert "updated >= '2023-01-01 00:00'" in jql
    assert "updated <= '2023-01-02 00:00'" in jql
def test_load_from_checkpoint_happy_path(
    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
    """Happy path: three issues arrive over two pages, followed by an empty page."""
    # Three fake issues spread across two result pages (page size is 2).
    issue_a = create_mock_issue(key="TEST-1", summary="Issue 1")
    issue_b = create_mock_issue(key="TEST-2", summary="Issue 2")
    issue_c = create_mock_issue(key="TEST-3", summary="Issue 3")

    # Stub only search_issues; everything else on the client is untouched.
    client = cast(JIRA, jira_connector._jira_client)
    searcher = cast(MagicMock, client.search_issues)
    searcher.side_effect = [
        [issue_a, issue_b],
        [issue_c],
        [],
    ]

    until = time.time()
    outputs = load_everything_from_checkpoint_connector(jira_connector, 0, until)

    # Two checkpoint outputs are produced for the two non-empty pages.
    assert len(outputs) == 2

    first = outputs[0]
    assert len(first.items) == 2
    doc_one = first.items[0]
    assert isinstance(doc_one, Document)
    assert doc_one.id == "https://jira.example.com/browse/TEST-1"
    doc_two = first.items[1]
    assert isinstance(doc_two, Document)
    assert doc_two.id == "https://jira.example.com/browse/TEST-2"
    assert first.next_checkpoint == JiraConnectorCheckpoint(
        offset=2,
        has_more=True,
    )

    second = outputs[1]
    assert len(second.items) == 1
    doc_three = second.items[0]
    assert isinstance(doc_three, Document)
    assert doc_three.id == "https://jira.example.com/browse/TEST-3"
    assert second.next_checkpoint == JiraConnectorCheckpoint(
        offset=3,
        has_more=False,
    )

    # Pagination parameters must advance by the page size.
    assert searcher.call_count == 2
    _, first_call_kwargs = searcher.call_args_list[0]
    assert first_call_kwargs["startAt"] == 0
    assert first_call_kwargs["maxResults"] == PAGE_SIZE
    _, second_call_kwargs = searcher.call_args_list[1]
    assert second_call_kwargs["startAt"] == 2
    assert second_call_kwargs["maxResults"] == PAGE_SIZE
def test_load_from_checkpoint_with_issue_processing_error(
    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
    """Test loading from checkpoint with a mix of successful and failed issue processing across multiple batches"""
    # Set up mocked issues for first batch
    mock_issue1 = create_mock_issue(key="TEST-1", summary="Issue 1")
    mock_issue2 = create_mock_issue(key="TEST-2", summary="Issue 2")
    # Set up mocked issues for second batch
    mock_issue3 = create_mock_issue(key="TEST-3", summary="Issue 3")
    mock_issue4 = create_mock_issue(key="TEST-4", summary="Issue 4")

    # Mock search_issues to return our mock issues in batches
    jira_client = cast(JIRA, jira_connector._jira_client)
    search_issues_mock = cast(MagicMock, jira_client.search_issues)
    search_issues_mock.side_effect = [
        [mock_issue1, mock_issue2],  # First batch
        [mock_issue3, mock_issue4],  # Second batch
        [],  # Empty batch to indicate end
    ]

    # Mock process_jira_issue to succeed for some issues and fail for others
    # (TEST-1 / TEST-3 succeed, every other key raises).
    def mock_process_side_effect(
        jira_client: JIRA, issue: Issue, *args: Any, **kwargs: Any
    ) -> Document | None:
        if issue.key in ["TEST-1", "TEST-3"]:
            return Document(
                id=f"https://jira.example.com/browse/{issue.key}",
                sections=[],
                source=DocumentSource.JIRA,
                semantic_identifier=f"{issue.key}: {issue.fields.summary}",
                title=f"{issue.key} {issue.fields.summary}",
                metadata={},
            )
        else:
            raise Exception(f"Processing error for {issue.key}")

    with patch(
        "onyx.connectors.onyx_jira.connector.process_jira_issue"
    ) as mock_process:
        mock_process.side_effect = mock_process_side_effect

        # Call load_from_checkpoint
        end_time = time.time()
        outputs = load_everything_from_checkpoint_connector(jira_connector, 0, end_time)

        # Three outputs: two document-bearing batches plus the final empty one.
        assert len(outputs) == 3

        # Check first batch
        first_batch = outputs[0]
        assert len(first_batch.items) == 2
        # First item should be successful
        assert isinstance(first_batch.items[0], Document)
        assert first_batch.items[0].id == "https://jira.example.com/browse/TEST-1"
        # Second item should be a failure
        assert isinstance(first_batch.items[1], ConnectorFailure)
        assert first_batch.items[1].failed_document is not None
        assert first_batch.items[1].failed_document.document_id == "TEST-2"
        assert "Failed to process Jira issue" in first_batch.items[1].failure_message
        # Check checkpoint indicates more items (full batch)
        assert first_batch.next_checkpoint.has_more is True
        assert first_batch.next_checkpoint.offset == 2

        # Check second batch
        second_batch = outputs[1]
        assert len(second_batch.items) == 2
        # First item should be successful
        assert isinstance(second_batch.items[0], Document)
        assert second_batch.items[0].id == "https://jira.example.com/browse/TEST-3"
        # Second item should be a failure
        assert isinstance(second_batch.items[1], ConnectorFailure)
        assert second_batch.items[1].failed_document is not None
        assert second_batch.items[1].failed_document.document_id == "TEST-4"
        assert "Failed to process Jira issue" in second_batch.items[1].failure_message
        # Check checkpoint indicates more items
        assert second_batch.next_checkpoint.has_more is True
        assert second_batch.next_checkpoint.offset == 4

        # Check third, empty batch — offset stays where the last page left it.
        third_batch = outputs[2]
        assert len(third_batch.items) == 0
        assert third_batch.next_checkpoint.has_more is False
        assert third_batch.next_checkpoint.offset == 4
def test_load_from_checkpoint_with_skipped_issue(
    jira_connector: JiraConnector, create_mock_issue: Callable[..., MagicMock]
) -> None:
    """An issue carrying a label in ``labels_to_skip`` must yield no documents."""
    skip_label = "secret"
    jira_connector.labels_to_skip = {skip_label}

    # Single mocked issue tagged with the to-be-skipped label.
    skipped_issue = create_mock_issue(
        key="TEST-1", summary="Issue 1", labels=[skip_label]
    )
    mocked_search = cast(
        MagicMock, cast(JIRA, jira_connector._jira_client).search_issues
    )
    mocked_search.return_value = [skipped_issue]

    outputs = load_everything_from_checkpoint_connector(
        jira_connector, 0, time.time()
    )

    # One (empty) batch: the labeled issue was filtered out entirely.
    assert len(outputs) == 1
    assert len(outputs[0].items) == 0
def test_retrieve_all_slim_documents(
    jira_connector: JiraConnector, create_mock_issue: Any
) -> None:
    """Slim retrieval should return a single batch of SlimDocuments for all issues."""
    issue_a = create_mock_issue(key="TEST-1")
    issue_b = create_mock_issue(key="TEST-2")

    jira_client = cast(JIRA, jira_connector._jira_client)
    search_issues_mock = cast(MagicMock, jira_client.search_issues)
    search_issues_mock.return_value = [issue_a, issue_b]

    # Stub out key extraction and URL construction for the two issues.
    with patch(
        "onyx.connectors.onyx_jira.connector.best_effort_get_field_from_issue"
    ) as mock_field, patch(
        "onyx.connectors.onyx_jira.connector.build_jira_url"
    ) as mock_url:
        mock_field.side_effect = ["TEST-1", "TEST-2"]
        mock_url.side_effect = [
            "https://jira.example.com/browse/TEST-1",
            "https://jira.example.com/browse/TEST-2",
        ]

        batches = list(jira_connector.retrieve_all_slim_documents(0, 100))

    # A single batch containing one SlimDocument per issue, in order.
    assert len(batches) == 1
    assert len(batches[0]) == 2
    assert isinstance(batches[0][0], SlimDocument)
    assert [doc.id for doc in batches[0]] == [
        "https://jira.example.com/browse/TEST-1",
        "https://jira.example.com/browse/TEST-2",
    ]

    # Only the "key" field should have been requested from Jira.
    search_issues_mock.assert_called_once()
    assert search_issues_mock.call_args.kwargs["fields"] == "key"
@pytest.mark.parametrize(
    "status_code,expected_exception,expected_message",
    [
        (
            401,
            CredentialExpiredError,
            "Jira credential appears to be expired or invalid",
        ),
        (
            403,
            InsufficientPermissionsError,
            "Your Jira token does not have sufficient permissions",
        ),
        (404, ConnectorValidationError, "Jira project not found"),
        (
            429,
            ConnectorValidationError,
            "Validation failed due to Jira rate-limits being exceeded",
        ),
    ],
)
def test_validate_connector_settings_errors(
    jira_connector: JiraConnector,
    status_code: int,
    expected_exception: type[Exception],
    expected_message: str,
) -> None:
    """Each HTTP failure status from Jira maps to its validation exception."""
    jira_error = JIRAError(status_code=status_code)
    mocked_project = cast(
        MagicMock, cast(JIRA, jira_connector._jira_client).project
    )
    mocked_project.side_effect = jira_error

    with pytest.raises(expected_exception) as excinfo:
        jira_connector.validate_connector_settings()

    assert expected_message in str(excinfo.value)
def test_validate_connector_settings_with_project_success(
    jira_connector: JiraConnector,
) -> None:
    """With a project configured, validation queries exactly that project."""
    mocked_project = cast(
        MagicMock, cast(JIRA, jira_connector._jira_client).project
    )
    mocked_project.return_value = MagicMock()

    jira_connector.validate_connector_settings()

    mocked_project.assert_called_once_with(jira_connector.jira_project)
def test_validate_connector_settings_without_project_success(
    jira_base_url: str,
) -> None:
    """Without a project, validation falls back to listing all projects."""
    connector = JiraConnector(jira_base_url=jira_base_url)
    mocked_client = MagicMock()
    mocked_client.projects.return_value = [MagicMock()]
    connector._jira_client = mocked_client

    connector.validate_connector_settings()

    mocked_client.projects.assert_called_once()

View File

@ -7,7 +7,8 @@ import pytest
from jira.resources import Issue
from pytest_mock import MockFixture
from onyx.connectors.onyx_jira.connector import fetch_jira_issues_batch
from onyx.connectors.onyx_jira.connector import _perform_jql_search
from onyx.connectors.onyx_jira.connector import process_jira_issue
@pytest.fixture
@ -79,14 +80,22 @@ def test_fetch_jira_issues_batch_small_ticket(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 1
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 1
assert docs[0].id.endswith("/SMALL-1")
assert docs[0].sections[0].text is not None
assert "Small description" in docs[0].sections[0].text
assert "Small comment 1" in docs[0].sections[0].text
assert "Small comment 2" in docs[0].sections[0].text
doc = docs[0]
assert doc is not None # Type assertion for mypy
assert doc.id.endswith("/SMALL-1")
assert doc.sections[0].text is not None
assert "Small description" in doc.sections[0].text
assert "Small comment 1" in doc.sections[0].text
assert "Small comment 2" in doc.sections[0].text
def test_fetch_jira_issues_batch_large_ticket(
@ -96,7 +105,13 @@ def test_fetch_jira_issues_batch_large_ticket(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 1
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 0 # The large ticket should be skipped
@ -109,10 +124,18 @@ def test_fetch_jira_issues_batch_mixed_tickets(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 2
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 1 # Only the small ticket should be included
assert docs[0].id.endswith("/SMALL-1")
doc = docs[0]
assert doc is not None # Type assertion for mypy
assert doc.id.endswith("/SMALL-1")
@patch("onyx.connectors.onyx_jira.connector.JIRA_CONNECTOR_MAX_TICKET_SIZE", 50)
@ -124,6 +147,12 @@ def test_fetch_jira_issues_batch_custom_size_limit(
) -> None:
mock_jira_client.search_issues.return_value = [mock_issue_small, mock_issue_large]
docs = list(fetch_jira_issues_batch(mock_jira_client, "project = TEST", 50))
# First get the issues via pagination
issues = list(_perform_jql_search(mock_jira_client, "project = TEST", 0, 50))
assert len(issues) == 2
# Then process each issue
docs = [process_jira_issue(mock_jira_client, issue) for issue in issues]
docs = [doc for doc in docs if doc is not None] # Filter out None values
assert len(docs) == 0 # Both tickets should be skipped due to the low size limit

View File

@ -0,0 +1,55 @@
from typing import cast
from typing import Generic
from typing import TypeVar
from pydantic import BaseModel
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
_ITERATION_LIMIT = 100_000
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class SingleConnectorCallOutput(BaseModel, Generic[CT]):
    """Everything produced by one `load_from_checkpoint` call of a connector."""

    # Documents successfully loaded plus any per-document failures, in yield order.
    items: list[Document | ConnectorFailure]
    # Checkpoint the connector handed back; drives the next iteration (if any).
    next_checkpoint: CT
def load_everything_from_checkpoint_connector(
    connector: CheckpointConnector[CT],
    start: SecondsSinceUnixEpoch,
    end: SecondsSinceUnixEpoch,
) -> list[SingleConnectorCallOutput[CT]]:
    """Drive a checkpoint connector to exhaustion over [start, end].

    Repeatedly calls ``load_from_checkpoint`` with the most recent checkpoint
    until the connector reports no more data, recording each call's output.
    Raises RuntimeError if the connector never terminates within the
    iteration limit (a likely infinite loop).
    """
    outputs: list[SingleConnectorCallOutput[CT]] = []
    checkpoint = cast(CT, connector.build_dummy_checkpoint())

    iteration = 0
    while checkpoint.has_more:
        collected: list[Document | ConnectorFailure] = []
        wrapped_generator = CheckpointOutputWrapper[CT]()(
            connector.load_from_checkpoint(start, end, checkpoint)
        )
        for document, failure, next_checkpoint in wrapped_generator:
            # Each yield carries at most one of these three; record failures
            # and documents, and keep the latest checkpoint seen.
            if failure is not None:
                collected.append(failure)
            if document is not None:
                collected.append(document)
            if next_checkpoint is not None:
                checkpoint = next_checkpoint

        outputs.append(
            SingleConnectorCallOutput(items=collected, next_checkpoint=checkpoint)
        )

        iteration += 1
        if iteration > _ITERATION_LIMIT:
            raise RuntimeError("Too many iterations. Infinite loop?")

    return outputs

View File

@ -1,10 +1,16 @@
import contextvars
import threading
import time
from collections.abc import Generator
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
import pytest
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import ThreadSafeDict
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a context variable for testing
@ -148,3 +154,237 @@ def test_multiple_background_tasks() -> None:
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations
def test_thread_safe_dict_basic_operations() -> None:
    """Exercise the core mapping protocol of ThreadSafeDict."""
    safe_dict = ThreadSafeDict[str, int]()

    # __setitem__ / __getitem__
    safe_dict["a"] = 1
    assert safe_dict["a"] == 1

    # get() with and without a hit
    assert safe_dict.get("a", None) == 1
    assert safe_dict.get("b", 2) == 2

    # __delitem__ / __contains__
    del safe_dict["a"]
    assert "a" not in safe_dict

    # __len__
    safe_dict["x"] = 10
    safe_dict["y"] = 20
    assert len(safe_dict) == 2

    # keys / items / values views
    assert sorted(safe_dict.keys()) == ["x", "y"]
    assert dict(safe_dict.items()) == {"x": 10, "y": 20}
    assert sorted(safe_dict.values()) == [10, 20]
def test_thread_safe_dict_concurrent_access() -> None:
    """Test ThreadSafeDict with concurrent access from multiple threads.

    The read-modify-write increment is performed while holding ``d.lock``:
    ``get`` and ``__setitem__`` are individually thread-safe, but the pair
    is not atomic, so without the lock a thread switch between the read and
    the write can lose increments and make the exact-count assertions below
    fail intermittently.
    """
    d = ThreadSafeDict[str, int]()
    num_threads = 10
    iterations = 1000

    def increment_values() -> None:
        for i in range(iterations):
            key = str(i % 5)  # Use 5 different keys
            # Atomically: read current value (default 0), increment, store.
            with d.lock:
                d[key] = d.get(key, 0) + 1

    # Create and start threads
    threads = []
    for _ in range(num_threads):
        t = threading.Thread(target=increment_values)
        threads.append(t)
        t.start()

    # Wait for all threads to complete
    for t in threads:
        t.join()

    # Verify results
    # Each key should have been incremented (num_threads * iterations) / 5 times
    expected_value = (num_threads * iterations) // 5
    for i in range(5):
        assert d[str(i)] == expected_value
def test_thread_safe_dict_bulk_operations() -> None:
    """update() (mapping and kwargs forms) and clear() behave like dict's."""
    safe_dict = ThreadSafeDict[str, int]()

    # update() from a mapping
    safe_dict.update({"a": 1, "b": 2})
    assert dict(safe_dict.items()) == {"a": 1, "b": 2}

    # update() from keyword arguments merges with existing entries
    safe_dict.update(c=3, d=4)
    assert dict(safe_dict.items()) == {"a": 1, "b": 2, "c": 3, "d": 4}

    # clear() empties the dict
    safe_dict.clear()
    assert len(safe_dict) == 0
def test_thread_safe_dict_concurrent_bulk_operations() -> None:
    """Concurrent update() calls from several threads must not lose entries."""
    safe_dict = ThreadSafeDict[str, int]()
    worker_count = 5

    def bulk_update(start: int) -> None:
        # Each worker writes its own disjoint range of 20 keys.
        safe_dict.update({str(i): i for i in range(start, start + 20)})
        time.sleep(0.01)  # widen the window for thread overlap

    # Fan the updates out across a thread pool and wait for completion.
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        pending = [
            executor.submit(bulk_update, worker * 20)
            for worker in range(worker_count)
        ]
        for future in pending:
            future.result()

    # Every key from every worker must be present with its value intact.
    assert len(safe_dict) == worker_count * 20
    for i in range(worker_count * 20):
        assert safe_dict[str(i)] == i
def test_thread_safe_dict_atomic_operations() -> None:
    """Test atomic operations with ThreadSafeDict's lock.

    Each worker appends its own range of five numbers to the shared list
    while holding ``d.lock``, so an entire range is written without
    interleaving from other workers.
    """
    d = ThreadSafeDict[str, list[int]]()
    d["numbers"] = []

    def append_numbers(start: int) -> None:
        # NOTE(review): the list is read outside the lock, but every worker
        # receives the same list object, so appends under the lock mutate
        # shared state; the re-assignment below publishes it back explicitly.
        numbers = d["numbers"]
        with d.lock:
            for i in range(start, start + 5):
                numbers.append(i)
                time.sleep(0.001)  # Add delay to increase chance of thread overlap
            d["numbers"] = numbers

    # Run concurrent append operations
    threads = []
    for i in range(4):  # 4 threads, each adding 5 numbers
        t = threading.Thread(target=append_numbers, args=(i * 5,))
        threads.append(t)
        t.start()

    for t in threads:
        t.join()

    # Verify results
    numbers = d["numbers"]
    assert len(numbers) == 20  # 4 threads * 5 numbers each
    assert sorted(numbers) == list(range(20))  # All numbers 0-19 should be present
def test_parallel_yield_basic() -> None:
    """parallel_yield interleaves several generators, favoring the faster ones."""

    def make_gen(values: list[int], delay: float) -> Generator[int, None, None]:
        for value in values:
            time.sleep(delay)
            yield value

    # Three generators at different speeds.
    sources = [
        make_gen([1, 4, 7], 0.1),  # slower generator
        make_gen([2, 5, 8], 0.05),  # faster generator
        make_gen([3, 6, 9], 0.15),  # slowest generator
    ]

    # Collect each yielded value along with its arrival time.
    started = time.time()
    stamped: list[tuple[float, int]] = []
    for value in parallel_yield(sources):
        stamped.append((time.time() - started, value))

    # All nine values must come through.
    assert sorted(v for _, v in stamped) == list(range(1, 10))

    # Values 1,4,7 came from gen1; 2,5,8 from gen2; 3,6,9 from gen3.
    def mean_arrival(values: tuple[int, ...]) -> float:
        times = [t for t, v in stamped if v in values]
        return sum(times) / len(times)

    # The fastest generator should have the lowest average yield time.
    assert mean_arrival((2, 5, 8)) < mean_arrival((1, 4, 7))
    assert mean_arrival((2, 5, 8)) < mean_arrival((3, 6, 9))
def test_parallel_yield_empty_generators() -> None:
    """parallel_yield over only-empty generators yields nothing."""

    def empty_gen() -> Iterator[int]:
        return
        yield 0  # unreachable; present only to make this a generator function

    assert list(parallel_yield([empty_gen() for _ in range(3)])) == []
def test_parallel_yield_different_lengths() -> None:
    """Generators of unequal length should all be fully drained."""

    def counting_gen(count: int) -> Iterator[int]:
        for i in range(count):
            yield i
            time.sleep(0.01)  # small delay to ensure concurrent execution

    results = list(
        parallel_yield([counting_gen(1), counting_gen(3), counting_gen(2)])
    )

    # 1 + 3 + 2 items in total, regardless of arrival order.
    assert len(results) == 6
    assert sorted(results) == [0, 0, 0, 1, 1, 2]
def test_parallel_yield_exception_handling() -> None:
    """An exception raised inside one generator must propagate out."""

    def failing_gen() -> Iterator[int]:
        yield 1
        raise ValueError("Generator failure")

    def normal_gen() -> Iterator[int]:
        yield from (2, 3)

    with pytest.raises(ValueError, match="Generator failure"):
        list(parallel_yield([failing_gen(), normal_gen()]))
def test_parallel_yield_non_blocking() -> None:
    """Plain (non-blocking) range generators lose no values through parallel_yield."""

    def range_gen(start: int, end: int) -> Iterator[int]:
        yield from range(start, end)

    results = list(
        parallel_yield([range_gen(0, 100), range_gen(100, 200), range_gen(200, 300)])
    )

    # Every value 0..299 appears exactly once.
    assert len(results) == 300
    assert sorted(results) == list(range(300))

View File

@ -61,7 +61,7 @@ import {
import { buildImgUrl } from "@/app/chat/files/images/utils";
import { useAssistants } from "@/components/context/AssistantsContext";
import { debounce } from "lodash";
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { LLMProviderView } from "../configuration/llm/interfaces";
import StarterMessagesList from "./StarterMessageList";
import { Switch, SwitchField } from "@/components/ui/switch";
@ -123,7 +123,7 @@ export function AssistantEditor({
documentSets: DocumentSet[];
user: User | null;
defaultPublic: boolean;
llmProviders: FullLLMProvider[];
llmProviders: LLMProviderView[];
tools: ToolSnapshot[];
shouldAddAssistantToUserPreferences?: boolean;
admin?: boolean;

View File

@ -1,4 +1,4 @@
import { FullLLMProvider } from "../configuration/llm/interfaces";
import { LLMProviderView } from "../configuration/llm/interfaces";
import { Persona, StarterMessage } from "./interfaces";
interface PersonaUpsertRequest {
@ -319,7 +319,7 @@ export function checkPersonaRequiresImageGeneration(persona: Persona) {
}
export function providersContainImageGeneratingSupport(
providers: FullLLMProvider[]
providers: LLMProviderView[]
) {
return providers.some((provider) => provider.provider === "openai");
}

View File

@ -184,6 +184,10 @@ export function SlackChannelConfigFormFields({
name: channel.name,
value: channel.id,
}));
},
{
shouldRetryOnError: false, // don't spam the Slack API
dedupingInterval: 60000, // Limit re-fetching to once per minute
}
);

View File

@ -6,11 +6,9 @@ import { ErrorCallout } from "@/components/ErrorCallout";
import { ThreeDotsLoader } from "@/components/Loading";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import { usePopup } from "@/components/admin/connectors/Popup";
import Link from "next/link";
import { SlackChannelConfigsTable } from "./SlackChannelConfigsTable";
import { useSlackBot, useSlackChannelConfigsByBot } from "./hooks";
import { ExistingSlackBotForm } from "../SlackBotUpdateForm";
import { FiPlusSquare } from "react-icons/fi";
import { Separator } from "@/components/ui/separator";
function SlackBotEditPage({
@ -37,7 +35,11 @@ function SlackBotEditPage({
} = useSlackChannelConfigsByBot(Number(unwrappedParams["bot-id"]));
if (isSlackBotLoading || isSlackChannelConfigsLoading) {
return <ThreeDotsLoader />;
return (
<div className="flex justify-center items-center h-screen">
<ThreeDotsLoader />
</div>
);
}
if (slackBotError || !slackBot) {
@ -67,7 +69,7 @@ function SlackBotEditPage({
}
return (
<div className="container mx-auto">
<>
<InstantSSRAutoRefresh />
<BackButton routerOverride="/admin/bots" />
@ -86,8 +88,18 @@ function SlackBotEditPage({
setPopup={setPopup}
/>
</div>
</div>
</>
);
}
export default SlackBotEditPage;
export default function Page({
params,
}: {
params: Promise<{ "bot-id": string }>;
}) {
return (
<div className="container mx-auto">
<SlackBotEditPage params={params} />
</div>
);
}

View File

@ -1,5 +1,5 @@
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { Modal } from "@/components/Modal";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
@ -19,7 +19,7 @@ function LLMProviderUpdateModal({
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
@ -61,7 +61,7 @@ function LLMProviderDisplay({
shouldMarkAsDefault,
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null | undefined;
existingLlmProvider: FullLLMProvider;
existingLlmProvider: LLMProviderView;
shouldMarkAsDefault?: boolean;
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
@ -146,7 +146,7 @@ export function ConfiguredLLMProviderDisplay({
existingLlmProviders,
llmProviderDescriptors,
}: {
existingLlmProviders: FullLLMProvider[];
existingLlmProviders: LLMProviderView[];
llmProviderDescriptors: WellKnownLLMProviderDescriptor[];
}) {
existingLlmProviders = existingLlmProviders.sort((a, b) => {

View File

@ -21,7 +21,7 @@ import {
} from "@/components/admin/connectors/Field";
import { useState } from "react";
import { useSWRConfig } from "swr";
import { FullLLMProvider } from "./interfaces";
import { LLMProviderView } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@ -43,7 +43,7 @@ export function CustomLLMProviderUpdateForm({
hideSuccess,
}: {
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
hideSuccess?: boolean;
@ -165,7 +165,7 @@ export function CustomLLMProviderUpdateForm({
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider;
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{

View File

@ -9,7 +9,7 @@ import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
import { Button } from "@/components/ui/button";
import { ThreeDotsLoader } from "@/components/Loading";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { LLMProviderUpdateForm } from "./LLMProviderUpdateForm";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
@ -25,7 +25,7 @@ function LLMProviderUpdateModal({
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor | null;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
setPopup?: (popup: PopupSpec) => void;
}) {
@ -99,7 +99,7 @@ function DefaultLLMProviderDisplay({
function AddCustomLLMProvider({
existingLlmProviders,
}: {
existingLlmProviders: FullLLMProvider[];
existingLlmProviders: LLMProviderView[];
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
@ -130,7 +130,7 @@ export function LLMConfiguration() {
const { data: llmProviderDescriptors } = useSWR<
WellKnownLLMProviderDescriptor[]
>("/api/admin/llm/built-in/options", errorHandlingFetcher);
const { data: existingLlmProviders } = useSWR<FullLLMProvider[]>(
const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
LLM_PROVIDERS_ADMIN_URL,
errorHandlingFetcher
);

View File

@ -14,7 +14,7 @@ import {
import { useState } from "react";
import { useSWRConfig } from "swr";
import { defaultModelsByProvider, getDisplayNameForModel } from "@/lib/hooks";
import { FullLLMProvider, WellKnownLLMProviderDescriptor } from "./interfaces";
import { LLMProviderView, WellKnownLLMProviderDescriptor } from "./interfaces";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import * as Yup from "yup";
import isEqual from "lodash/isEqual";
@ -31,7 +31,7 @@ export function LLMProviderUpdateForm({
}: {
llmProviderDescriptor: WellKnownLLMProviderDescriptor;
onClose: () => void;
existingLlmProvider?: FullLLMProvider;
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
hideAdvanced?: boolean;
setPopup?: (popup: PopupSpec) => void;
@ -73,6 +73,7 @@ export function LLMProviderUpdateForm({
defaultModelsByProvider[llmProviderDescriptor.name] ||
[],
deployment_name: existingLlmProvider?.deployment_name,
api_key_changed: false,
};
// Setup validation schema if required
@ -113,6 +114,7 @@ export function LLMProviderUpdateForm({
is_public: Yup.boolean().required(),
groups: Yup.array().of(Yup.number()),
display_model_names: Yup.array().of(Yup.string()),
api_key_changed: Yup.boolean(),
});
return (
@ -122,6 +124,8 @@ export function LLMProviderUpdateForm({
onSubmit={async (values, { setSubmitting }) => {
setSubmitting(true);
values.api_key_changed = values.api_key !== initialValues.api_key;
// test the configuration
if (!isEqual(values, initialValues)) {
setIsTesting(true);
@ -180,7 +184,7 @@ export function LLMProviderUpdateForm({
}
if (shouldMarkAsDefault) {
const newLlmProvider = (await response.json()) as FullLLMProvider;
const newLlmProvider = (await response.json()) as LLMProviderView;
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${newLlmProvider.id}/default`,
{

View File

@ -53,14 +53,14 @@ export interface LLMProvider {
is_default_vision_provider: boolean | null;
}
export interface FullLLMProvider extends LLMProvider {
export interface LLMProviderView extends LLMProvider {
id: number;
is_default_provider: boolean | null;
model_names: string[];
icon?: React.FC<{ size?: number; className?: string }>;
}
export interface VisionProvider extends FullLLMProvider {
export interface VisionProvider extends LLMProviderView {
vision_models: string[];
}

View File

@ -54,6 +54,7 @@ export const SourceCard: React.FC<{
<div className="flex items-center gap-1 mt-1">
<ResultIcon doc={document} size={18} />
<div className="text-text-700 text-xs leading-tight truncate flex-1 min-w-0">
{truncatedIdentifier}
</div>

View File

@ -49,7 +49,7 @@ export function SearchResultIcon({ url }: { url: string }) {
if (!faviconUrl) {
return <SourceIcon sourceType={ValidSources.Web} iconSize={18} />;
}
if (url.includes("docs.onyx.app")) {
if (url.includes("onyx.app")) {
return <OnyxIcon size={18} className="dark:text-[#fff] text-[#000]" />;
}

View File

@ -17,12 +17,11 @@ export function WebResultIcon({
try {
hostname = new URL(url).hostname;
} catch (e) {
// console.log(e);
hostname = "docs.onyx.app";
hostname = "onyx.app";
}
return (
<>
{hostname == "docs.onyx.app" ? (
{hostname.includes("onyx.app") ? (
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
) : !error ? (
<img

View File

@ -26,35 +26,6 @@ export const ResultIcon = ({
);
};
// export default function SourceCard({
// doc,
// setPresentingDocument,
// }: {
// doc: OnyxDocument;
// setPresentingDocument?: (document: OnyxDocument) => void;
// }) {
// return (
// <div
// key={doc.document_id}
// onClick={() => openDocument(doc, setPresentingDocument)}
// className="cursor-pointer h-[80px] text-left overflow-hidden flex flex-col gap-0.5 rounded-lg px-3 py-2 bg-accent-background hover:bg-accent-background-hovered w-[200px]"
// >
// <div className="line-clamp-1 font-semibold text-ellipsis text-text-900 flex h-6 items-center gap-2 text-sm">
// {doc.is_internet || doc.source_type === "web" ? (
// <WebResultIcon url={doc.link} />
// ) : (
// <SourceIcon sourceType={doc.source_type} iconSize={18} />
// )}
// <p>{truncateString(doc.semantic_identifier || doc.document_id, 20)}</p>
// </div>
// <div className="line-clamp-2 text-sm font-semibold"></div>
// <div className="line-clamp-2 text-sm font-normal leading-snug text-text-700">
// {doc.blurb}
// </div>
// </div>
// );
// }
interface SeeMoreBlockProps {
toggleDocumentSelection: () => void;
docs: OnyxDocument[];

View File

@ -1,5 +1,5 @@
import {
FullLLMProvider,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { User } from "@/lib/types";
@ -36,7 +36,7 @@ export async function checkLlmProvider(user: User | null) {
const [providerResponse, optionsResponse, defaultCheckResponse] =
await Promise.all(tasks);
let providers: FullLLMProvider[] = [];
let providers: LLMProviderView[] = [];
if (providerResponse?.ok) {
providers = await providerResponse.json();
}

View File

@ -3,7 +3,7 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";
import {
FullLLMProvider,
LLMProviderView,
getProviderIcon,
} from "@/app/admin/configuration/llm/interfaces";
import { ToolSnapshot } from "../tools/interfaces";
@ -16,7 +16,7 @@ export async function fetchAssistantEditorInfoSS(
{
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
llmProviders: FullLLMProvider[];
llmProviders: LLMProviderView[];
user: User | null;
existingPersona: Persona | null;
tools: ToolSnapshot[];
@ -83,7 +83,7 @@ export async function fetchAssistantEditorInfoSS(
];
}
const llmProviders = (await llmProvidersResponse.json()) as FullLLMProvider[];
const llmProviders = (await llmProvidersResponse.json()) as LLMProviderView[];
if (personaId && personaResponse && !personaResponse.ok) {
return [null, `Failed to fetch Persona - ${await personaResponse.text()}`];