diff --git a/.github/workflows/pr-integration-tests.yml b/.github/workflows/pr-integration-tests.yml index 5573c51d9..593a9adfd 100644 --- a/.github/workflows/pr-integration-tests.yml +++ b/.github/workflows/pr-integration-tests.yml @@ -145,7 +145,7 @@ jobs: run: | cd deployment/docker_compose docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v - + # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections - name: Start Docker containers run: | @@ -157,6 +157,7 @@ jobs: REQUIRE_EMAIL_VERIFICATION=false \ DISABLE_TELEMETRY=true \ IMAGE_TAG=test \ + INTEGRATION_TESTS_MODE=true \ docker compose -f docker-compose.dev.yml -p onyx-stack up -d id: start_docker @@ -199,7 +200,7 @@ jobs: cd backend/tests/integration/mock_services docker compose -f docker-compose.mock-it-services.yml \ -p mock-it-services-stack up -d - + # NOTE: Use pre-ping/null to reduce flakiness due to dropped connections - name: Run Standard Integration Tests run: | diff --git a/.github/workflows/pr-python-model-tests.yml b/.github/workflows/pr-python-model-tests.yml index a070eea27..0421e1228 100644 --- a/.github/workflows/pr-python-model-tests.yml +++ b/.github/workflows/pr-python-model-tests.yml @@ -1,10 +1,16 @@ -name: Connector Tests +name: Model Server Tests on: schedule: # This cron expression runs the job daily at 16:00 UTC (9am PT) - cron: "0 16 * * *" - + workflow_dispatch: + inputs: + branch: + description: 'Branch to run the workflow on' + required: false + default: 'main' + env: # Bedrock AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} @@ -26,6 +32,23 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} + + # tag every docker image with "test" so that we can spin up the correct set + # of images during testing + + # We don't need to build the Web Docker image since it's not yet used + # in the integration tests. We have a separate action to verify that it builds + # successfully. + - name: Pull Model Server Docker image + run: | + docker pull onyxdotapp/onyx-model-server:latest + docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test + - name: Set up Python uses: actions/setup-python@v5 with: @@ -41,6 +64,49 @@ jobs: pip install --retries 5 --timeout 30 -r backend/requirements/default.txt pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt + - name: Start Docker containers + run: | + cd deployment/docker_compose + ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ + AUTH_TYPE=basic \ + REQUIRE_EMAIL_VERIFICATION=false \ + DISABLE_TELEMETRY=true \ + IMAGE_TAG=test \ + docker compose -f docker-compose.dev.yml -p onyx-stack up -d indexing_model_server + id: start_docker + + - name: Wait for service to be ready + run: | + echo "Starting wait-for-service script..." + + start_time=$(date +%s) + timeout=300 # 5 minutes in seconds + + while true; do + current_time=$(date +%s) + elapsed_time=$((current_time - start_time)) + + if [ $elapsed_time -ge $timeout ]; then + echo "Timeout reached. Service did not become ready in 5 minutes." + exit 1 + fi + + # Use curl with error handling to ignore specific exit code 56 + response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error") + + if [ "$response" = "200" ]; then + echo "Service is ready!" + break + elif [ "$response" = "curl_error" ]; then + echo "Curl encountered an error, possibly exit code 56. 
Continuing to retry..." + else + echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..." + fi + + sleep 5 + done + echo "Finished waiting for service." + - name: Run Tests shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}" run: | @@ -56,3 +122,10 @@ jobs: -H 'Content-type: application/json' \ --data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \ $SLACK_WEBHOOK + + - name: Stop Docker containers + if: always() + run: | + cd deployment/docker_compose + docker compose -f docker-compose.dev.yml -p onyx-stack down -v + diff --git a/README.md b/README.md index ea303c59c..442f8404f 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,12 @@ [Onyx](https://www.onyx.app/) (formerly Danswer) is the AI platform connected to your company's docs, apps, and people. Onyx provides a feature rich Chat interface and plugs into any LLM of your choice. -There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date. -Create custom AI agents with unique prompts, knowledge, and actions the agents can take. +Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc. +Create custom AI agents with unique prompts, knowledge, and actions that the agents can take. Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud. -
-Feature Showcase
-</h3>
+<h3>
+Feature Highlights
+</h3>
**Deep research over your team's knowledge:** @@ -63,22 +63,21 @@ We also have built-in support for high-availability/scalable deployment on Kuber References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment). +## 🔍 Other Notable Benefits of Onyx +- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback. +- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc. +- Knowledge curation features like document-sets, query history, usage analytics, etc. +- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents. + + ## 🚧 Roadmap -- Extensions to the Chrome Plugin -- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.) +- New methods in information retrieval (StructRAG, LightGraphRAG, etc.) - Personalized Search - Organizational understanding and ability to locate and suggest experts from your team. - Code Search - SQL and Structured Query Language -## 🔍 Other Notable Benefits of Onyx -- Custom deep learning models only through Onyx + learn from user feedback. -- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc. -- Knowledge curation features like document-sets, query history, usage analytics, etc. -- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents. - - ## 🔌 Connectors Keep knowledge and access up to sync across 40+ connectors: diff --git a/backend/ee/onyx/server/query_and_chat/token_limit.py b/backend/ee/onyx/server/query_and_chat/token_limit.py index 5ee53b8f3..c6cd8486e 100644 --- a/backend/ee/onyx/server/query_and_chat/token_limit.py +++ b/backend/ee/onyx/server/query_and_chat/token_limit.py @@ -13,7 +13,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.api_key import is_api_key_email_address -from onyx.db.engine import get_session_with_tenant +from onyx.db.engine import get_session_with_current_tenant from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import TokenRateLimit @@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel -def _check_token_rate_limits(user: User | None, tenant_id: str) -> None: +def _check_token_rate_limits(user: User | None) -> None: if user is None: # Unauthenticated users are only rate limited by global settings - _user_is_rate_limited_by_global(tenant_id) + _user_is_rate_limited_by_global() elif is_api_key_email_address(user.email): # API keys are only rate limited by global settings - _user_is_rate_limited_by_global(tenant_id) + _user_is_rate_limited_by_global() else: run_functions_tuples_in_parallel( [ - (_user_is_rate_limited, (user.id, tenant_id)), - (_user_is_rate_limited_by_group, (user.id, tenant_id)), - (_user_is_rate_limited_by_global, (tenant_id,)), + (_user_is_rate_limited, (user.id,)), + (_user_is_rate_limited_by_group, (user.id,)), + (_user_is_rate_limited_by_global, ()), ] ) @@ -52,8 +52,8 @@ User rate limits """ -def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def _user_is_rate_limited(user_id: UUID) -> None: + with get_session_with_current_tenant() as db_session: user_rate_limits = fetch_all_user_token_rate_limits( db_session=db_session, enabled_only=True, ordered=False ) @@ -93,8 +93,8 
@@ User Group rate limits """ -def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def _user_is_rate_limited_by_group(user_id: UUID) -> None: + with get_session_with_current_tenant() as db_session: group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session) if group_rate_limits: diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 52e1ddf8b..8521cd001 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -98,12 +98,17 @@ class CloudEmbedding: return final_embeddings except Exception as e: error_string = ( - f"Error embedding text with OpenAI: {str(e)} \n" - f"Model: {model} \n" - f"Provider: {self.provider} \n" - f"Texts: {texts}" + f"Exception embedding text with OpenAI - {type(e)}: " + f"Model: {model} " + f"Provider: {self.provider} " + f"Exception: {e}" ) logger.error(error_string) + + # only log text when it's not an authentication error. + if not isinstance(e, openai.AuthenticationError): + logger.debug(f"Exception texts: {texts}") + raise RuntimeError(error_string) async def _embed_cohere( diff --git a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py index 50bdb5f32..2db26cfb6 100644 --- a/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py +++ b/backend/onyx/background/celery/tasks/external_group_syncing/tasks.py @@ -361,6 +361,7 @@ def connector_external_group_sync_generator_task( cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, + eager_load_credential=True, ) if cc_pair is None: raise ValueError( diff --git a/backend/onyx/background/indexing/run_indexing.py b/backend/onyx/background/indexing/run_indexing.py index 41823754a..9729f61dd 100644 --- a/backend/onyx/background/indexing/run_indexing.py +++ b/backend/onyx/background/indexing/run_indexing.py @@ -15,6 +15,7 @@ from onyx.background.indexing.memory_tracer import MemoryTracer from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL +from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET from onyx.configs.constants import DocumentSource @@ -89,8 +90,8 @@ def _get_connector_runner( ) # validate the connector settings - - runnable_connector.validate_connector_settings() + if not INTEGRATION_TESTS_MODE: + runnable_connector.validate_connector_settings() except Exception as e: logger.exception(f"Unable to instantiate connector due to {e}") diff --git a/backend/onyx/chat/process_message.py b/backend/onyx/chat/process_message.py index a3cce6b7b..2bc43e368 100644 --- a/backend/onyx/chat/process_message.py +++ b/backend/onyx/chat/process_message.py @@ -747,14 +747,13 @@ def stream_chat_message_objects( files=latest_query_files, single_message_history=single_message_history, ), - system_message=default_build_system_message(prompt_config), + system_message=default_build_system_message(prompt_config, llm.config), message_history=message_history, llm_config=llm.config, raw_user_query=final_msg.message, raw_user_uploaded_files=latest_query_files or [], single_message_history=single_message_history, ) - 
prompt_builder.update_system_prompt(default_build_system_message(prompt_config)) # LLM prompt building, response capturing, etc. answer = Answer( @@ -870,7 +869,6 @@ def stream_chat_message_objects( for img in img_generation_response if img.image_data ], - tenant_id=tenant_id, ) info.ai_message_files.extend( [ diff --git a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py index c7cdec8f9..8e175f576 100644 --- a/backend/onyx/chat/prompt_builder/answer_prompt_builder.py +++ b/backend/onyx/chat/prompt_builder/answer_prompt_builder.py @@ -12,6 +12,7 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke from onyx.chat.prompt_builder.utils import translate_history_to_basemessages from onyx.file_store.models import InMemoryChatFile from onyx.llm.interfaces import LLMConfig +from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME from onyx.llm.models import PreviousMessage from onyx.llm.utils import build_content_with_imgs from onyx.llm.utils import check_message_tokens @@ -19,6 +20,7 @@ from onyx.llm.utils import message_to_prompt_and_imgs from onyx.llm.utils import model_supports_image_input from onyx.natural_language_processing.utils import get_tokenizer from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT +from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK from onyx.prompts.prompt_utils import drop_messages_history_overflow from onyx.prompts.prompt_utils import handle_onyx_date_awareness @@ -31,8 +33,16 @@ from onyx.tools.tool import Tool def default_build_system_message( prompt_config: PromptConfig, + llm_config: LLMConfig, ) -> SystemMessage | None: system_prompt = prompt_config.system_prompt.strip() + # See https://simonwillison.net/tags/markdown/ for context on this temporary fix + # for o-series markdown generation + if ( + llm_config.model_provider == OPENAI_PROVIDER_NAME + and llm_config.model_name.startswith("o") + ): + system_prompt = CODE_BLOCK_MARKDOWN + system_prompt tag_handled_prompt = handle_onyx_date_awareness( system_prompt, prompt_config, @@ -110,21 +120,8 @@ class AnswerPromptBuilder: ), ) - self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = ( - ( - system_message, - check_message_tokens(system_message, self.llm_tokenizer_encode_func), - ) - if system_message - else None - ) - self.user_message_and_token_cnt = ( - user_message, - check_message_tokens( - user_message, - self.llm_tokenizer_encode_func, - ), - ) + self.update_system_prompt(system_message) + self.update_user_prompt(user_message) self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = [] diff --git a/backend/onyx/configs/app_configs.py b/backend/onyx/configs/app_configs.py index 5e87734d2..ec0e4e2d5 100644 --- a/backend/onyx/configs/app_configs.py +++ b/backend/onyx/configs/app_configs.py @@ -158,7 +158,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres" POSTGRES_PASSWORD = urllib.parse.quote_plus( os.environ.get("POSTGRES_PASSWORD") or "password" ) -POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost" +POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1" POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432" POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres" AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2" @@ -626,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE") DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true" 
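+# NOTE: set to "true" by the integration test workflow; when enabled, connector
+# settings validation is skipped (see run_indexing.py and connectors/factory.py
+# in this diff) so tests can exercise mock connectors.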
+INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true" + MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH") TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true" diff --git a/backend/onyx/connectors/factory.py b/backend/onyx/connectors/factory.py index b4f497f65..c67b08ff8 100644 --- a/backend/onyx/connectors/factory.py +++ b/backend/onyx/connectors/factory.py @@ -3,6 +3,7 @@ from typing import Type from sqlalchemy.orm import Session +from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.constants import DocumentSource from onyx.configs.constants import DocumentSourceRequiringTenantContext from onyx.connectors.airtable.airtable_connector import AirtableConnector @@ -187,6 +188,9 @@ def validate_ccpair_for_user( user: User | None, tenant_id: str | None, ) -> None: + if INTEGRATION_TESTS_MODE: + return + # Validate the connector settings connector = fetch_connector_by_id(connector_id, db_session) credential = fetch_credential_by_id_for_user( @@ -199,7 +203,10 @@ def validate_ccpair_for_user( if not connector: raise ValueError("Connector not found") - if connector.source == DocumentSource.INGESTION_API: + if ( + connector.source == DocumentSource.INGESTION_API + or connector.source == DocumentSource.MOCK_CONNECTOR + ): return if not credential: diff --git a/backend/onyx/connectors/google_drive/connector.py b/backend/onyx/connectors/google_drive/connector.py index 1287a8960..46da04ff3 100644 --- a/backend/onyx/connectors/google_drive/connector.py +++ b/backend/onyx/connectors/google_drive/connector.py @@ -220,7 +220,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector): return self._creds def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: - self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] + try: + self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] + except KeyError: + raise ValueError( + "Primary admin email missing, " + "should not call this property " + "before calling load_credentials" + ) self._creds, new_creds_dict = get_google_creds( credentials=credentials, diff --git a/backend/onyx/db/connector_credential_pair.py b/backend/onyx/db/connector_credential_pair.py index 712e81894..7c73faaa7 100644 --- a/backend/onyx/db/connector_credential_pair.py +++ b/backend/onyx/db/connector_credential_pair.py @@ -194,9 +194,14 @@ def get_connector_credential_pair_from_id_for_user( def get_connector_credential_pair_from_id( db_session: Session, cc_pair_id: int, + eager_load_credential: bool = False, ) -> ConnectorCredentialPair | None: stmt = select(ConnectorCredentialPair).distinct() stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id) + + if eager_load_credential: + stmt = stmt.options(joinedload(ConnectorCredentialPair.credential)) + result = db_session.execute(stmt) return result.scalar_one_or_none() diff --git a/backend/onyx/file_store/utils.py b/backend/onyx/file_store/utils.py index 384990718..91198790a 100644 --- a/backend/onyx/file_store/utils.py +++ b/backend/onyx/file_store/utils.py @@ -8,7 +8,7 @@ import requests from sqlalchemy.orm import Session from onyx.configs.constants import FileOrigin -from onyx.db.engine import get_session_with_tenant +from onyx.db.engine import get_session_with_current_tenant from onyx.db.models import ChatMessage from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import FileDescriptor @@ -53,11 +53,11 @@ def load_all_chat_files( return 
files -def save_file_from_url(url: str, tenant_id: str) -> str: +def save_file_from_url(url: str) -> str: """NOTE: using multiple sessions here, since this is often called using multithreading. In practice, sharing a session has resulted in weird errors.""" - with get_session_with_tenant(tenant_id=tenant_id) as db_session: + with get_session_with_current_tenant() as db_session: response = requests.get(url) response.raise_for_status() @@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str: return unique_id -def save_file_from_base64(base64_string: str, tenant_id: str) -> str: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def save_file_from_base64(base64_string: str) -> str: + with get_session_with_current_tenant() as db_session: unique_id = str(uuid4()) file_store = get_default_file_store(db_session) file_store.save_file( @@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str: def save_file( - tenant_id: str, url: str | None = None, base64_data: str | None = None, ) -> str: """Save a file from either a URL or base64 encoded string. Args: - tenant_id: The tenant ID to save the file under url: URL to download file from base64_data: Base64 encoded file data @@ -111,22 +109,22 @@ def save_file( raise ValueError("Cannot specify both url and base64_data") if url is not None: - return save_file_from_url(url, tenant_id) + return save_file_from_url(url) elif base64_data is not None: - return save_file_from_base64(base64_data, tenant_id) + return save_file_from_base64(base64_data) else: raise ValueError("Must specify either url or base64_data") -def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]: +def save_files(urls: list[str], base64_files: list[str]) -> list[str]: # NOTE: be explicit about typing so that if we change things, we get notified funcs: list[ tuple[ - Callable[[str, str | None, str | None], str], - tuple[str, str | None, str | None], + Callable[[str | None, str | None], str], + tuple[str | None, str | None], ] - ] = [(save_file, (tenant_id, url, None)) for url in urls] + [ - (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files + ] = [(save_file, (url, None)) for url in urls] + [ + (save_file, (None, base64_file)) for base64_file in base64_files ] return run_functions_tuples_in_parallel(funcs) diff --git a/backend/onyx/prompts/chat_prompts.py b/backend/onyx/prompts/chat_prompts.py index aa13482b6..04cc33488 100644 --- a/backend/onyx/prompts/chat_prompts.py +++ b/backend/onyx/prompts/chat_prompts.py @@ -18,6 +18,7 @@ Remember to provide inline citations in the format [1], [2], [3], etc. ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}." +CODE_BLOCK_MARKDOWN = "Formatting re-enabled. 
" CHAT_USER_PROMPT = f""" Refer to the following context documents when responding to me.{{optional_ignore_statement}} diff --git a/backend/onyx/server/manage/slack_bot.py b/backend/onyx/server/manage/slack_bot.py index ae06bc813..53e0d0bbe 100644 --- a/backend/onyx/server/manage/slack_bot.py +++ b/backend/onyx/server/manage/slack_bot.py @@ -1,5 +1,3 @@ -from typing import Any - from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException @@ -345,6 +343,9 @@ def list_bot_configs( ] +MAX_CHANNELS = 200 + + @router.get( "/admin/slack-app/bots/{bot_id}/channels", ) @@ -353,38 +354,40 @@ def get_all_channels_from_slack_api( db_session: Session = Depends(get_session), _: User | None = Depends(current_admin_user), ) -> list[SlackChannel]: + """ + Fetches all channels from the Slack API. + If the workspace has 200 or more channels, we raise an error. + """ tokens = fetch_slack_bot_tokens(db_session, bot_id) if not tokens or "bot_token" not in tokens: raise HTTPException( status_code=404, detail="Bot token not found for the given bot ID" ) - bot_token = tokens["bot_token"] - client = WebClient(token=bot_token) + client = WebClient(token=tokens["bot_token"]) try: - channels = [] - cursor = None - while True: - response = client.conversations_list( - types="public_channel,private_channel", - exclude_archived=True, - limit=1000, - cursor=cursor, - ) - for channel in response["channels"]: - channels.append(SlackChannel(id=channel["id"], name=channel["name"])) + response = client.conversations_list( + types="public_channel,private_channel", + exclude_archived=True, + limit=MAX_CHANNELS, + ) - response_metadata: dict[str, Any] = response.get("response_metadata", {}) - if isinstance(response_metadata, dict): - cursor = response_metadata.get("next_cursor") - if not cursor: - break - else: - break + channels = [ + SlackChannel(id=channel["id"], name=channel["name"]) + for channel in response["channels"] + ] + + if len(channels) == MAX_CHANNELS: + raise HTTPException( + status_code=400, + detail=f"Workspace has {MAX_CHANNELS} or more channels.", + ) return channels + except SlackApiError as e: raise HTTPException( - status_code=500, detail=f"Error fetching channels from Slack API: {str(e)}" + status_code=500, + detail=f"Error fetching channels from Slack API: {str(e)}", ) diff --git a/backend/onyx/server/query_and_chat/token_limit.py b/backend/onyx/server/query_and_chat/token_limit.py index b94903a28..fc0bc629d 100644 --- a/backend/onyx/server/query_and_chat/token_limit.py +++ b/backend/onyx/server/query_and_chat/token_limit.py @@ -13,7 +13,6 @@ from sqlalchemy.orm import Session from onyx.auth.users import current_chat_accesssible_user from onyx.db.engine import get_session_context_manager -from onyx.db.engine import get_session_with_tenant from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import TokenRateLimit @@ -21,7 +20,6 @@ from onyx.db.models import User from onyx.db.token_limit import fetch_all_global_token_rate_limits from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation -from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @@ -39,13 +37,13 @@ def check_token_rate_limits( return versioned_rate_limit_strategy = fetch_versioned_implementation( - "onyx.server.query_and_chat.token_limit", "_check_token_rate_limits" + "onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__ ) - return 
versioned_rate_limit_strategy(user, get_current_tenant_id()) + return versioned_rate_limit_strategy(user) -def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None: - _user_is_rate_limited_by_global(tenant_id) +def _check_token_rate_limits(_: User | None) -> None: + _user_is_rate_limited_by_global() """ @@ -53,8 +51,8 @@ Global rate limits """ -def _user_is_rate_limited_by_global(tenant_id: str | None) -> None: - with get_session_with_tenant(tenant_id=tenant_id) as db_session: +def _user_is_rate_limited_by_global() -> None: + with get_session_context_manager() as db_session: global_rate_limits = fetch_all_global_token_rate_limits( db_session=db_session, enabled_only=True, ordered=False ) diff --git a/backend/onyx/utils/threadpool_concurrency.py b/backend/onyx/utils/threadpool_concurrency.py index f6a2b3fbe..4ef87348f 100644 --- a/backend/onyx/utils/threadpool_concurrency.py +++ b/backend/onyx/utils/threadpool_concurrency.py @@ -1,3 +1,4 @@ +import contextvars import threading import uuid from collections.abc import Callable @@ -14,10 +15,6 @@ logger = setup_logger() R = TypeVar("R") -# WARNING: it is not currently well understood whether we lose access to contextvars when functions are -# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless -# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or -# is not safe, update this comment. def run_functions_tuples_in_parallel( functions_with_args: list[tuple[Callable, tuple]], allow_failures: bool = False, @@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel( results = [] with ThreadPoolExecutor(max_workers=workers) as executor: + # The primary reason for propagating contextvars is to allow acquiring a db session + # that respects tenant id. Context.run is expected to be low-overhead, but if we later + # find that it is increasing latency we can make using it optional. future_to_index = { - executor.submit(func, *args): i + executor.submit(contextvars.copy_context().run, func, *args): i for i, (func, args) in enumerate(functions_with_args) } @@ -83,10 +83,6 @@ class FunctionCall(Generic[R]): return self.func(*self.args, **self.kwargs) -# WARNING: it is not currently well understood whether we lose access to contextvars when functions are -# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless -# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or -# is not safe, update this comment. def run_functions_in_parallel( function_calls: list[FunctionCall], allow_failures: bool = False, @@ -102,7 +98,9 @@ def run_functions_in_parallel( with ThreadPoolExecutor(max_workers=len(function_calls)) as executor: future_to_id = { - executor.submit(func_call.execute): func_call.result_id + executor.submit( + contextvars.copy_context().run, func_call.execute + ): func_call.result_id for func_call in function_calls } @@ -143,10 +141,6 @@ class TimeoutThread(threading.Thread): ) -# WARNING: it is not currently well understood whether we lose access to contextvars when functions are -# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless -# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or -# is not safe, update this comment. 
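+# NOTE: like the helpers above, run_with_timeout now propagates contextvars
+# (e.g. the current tenant id) into the spawned thread via copy_context().run.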
def run_with_timeout( timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any ) -> R: @@ -154,7 +148,8 @@ def run_with_timeout( Executes a function with a timeout. If the function doesn't complete within the specified timeout, raises TimeoutError. """ - task = TimeoutThread(timeout, func, *args, **kwargs) + context = contextvars.copy_context() + task = TimeoutThread(timeout, context.run, func, *args, **kwargs) task.start() task.join(timeout) diff --git a/backend/scripts/debugging/onyx_redis.py b/backend/scripts/debugging/onyx_redis.py index 7d3334cf2..6b42a0846 100644 --- a/backend/scripts/debugging/onyx_redis.py +++ b/backend/scripts/debugging/onyx_redis.py @@ -3,6 +3,7 @@ import json import logging import sys import time +from enum import Enum from logging import getLogger from typing import cast from uuid import UUID @@ -20,10 +21,13 @@ from onyx.configs.app_configs import REDIS_PORT from onyx.configs.app_configs import REDIS_SSL from onyx.db.engine import get_session_with_tenant from onyx.db.users import get_user_by_email +from onyx.redis.redis_connector import RedisConnector +from onyx.redis.redis_connector_index import RedisConnectorIndex from onyx.redis.redis_pool import RedisPool from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR +from shared_configs.contextvars import get_current_tenant_id # Tool to run helpful operations on Redis in production # This is targeted for internal usage and may not have all the necessary parameters @@ -42,6 +46,19 @@ SCAN_ITER_COUNT = 10000 BATCH_DEFAULT = 1000 +class OnyxRedisCommand(Enum): + purge_connectorsync_taskset = "purge_connectorsync_taskset" + purge_documentset_taskset = "purge_documentset_taskset" + purge_usergroup_taskset = "purge_usergroup_taskset" + purge_locks_blocking_deletion = "purge_locks_blocking_deletion" + purge_vespa_syncing = "purge_vespa_syncing" + get_user_token = "get_user_token" + delete_user_token = "delete_user_token" + + def __str__(self) -> str: + return self.value + + def get_user_id(user_email: str) -> tuple[UUID, str]: tenant_id = ( get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA @@ -55,50 +72,79 @@ def get_user_id(user_email: str) -> tuple[UUID, str]: def onyx_redis( - command: str, + command: OnyxRedisCommand, batch: int, dry_run: bool, + ssl: bool, host: str, port: int, db: int, password: str | None, user_email: str | None = None, + cc_pair_id: int | None = None, ) -> int: + # this is global and not tenant aware pool = RedisPool.create_pool( host=host, port=port, db=db, password=password if password else "", - ssl=REDIS_SSL, + ssl=ssl, ssl_cert_reqs="optional", ssl_ca_certs=None, ) r = Redis(connection_pool=pool) + logger.info("Redis ping starting. This may hang if your settings are incorrect.") + try: r.ping() except: logger.exception("Redis ping exceptioned") raise - if command == "purge_connectorsync_taskset": + logger.info("Redis ping succeeded.") + + if command == OnyxRedisCommand.purge_connectorsync_taskset: """Purge connector tasksets. 
Used when the tasks represented in the tasksets have been purged.""" return purge_by_match_and_type( "*connectorsync_taskset*", "set", batch, dry_run, r ) - elif command == "purge_documentset_taskset": + elif command == OnyxRedisCommand.purge_documentset_taskset: return purge_by_match_and_type( "*documentset_taskset*", "set", batch, dry_run, r ) - elif command == "purge_usergroup_taskset": + elif command == OnyxRedisCommand.purge_usergroup_taskset: return purge_by_match_and_type("*usergroup_taskset*", "set", batch, dry_run, r) - elif command == "purge_vespa_syncing": + elif command == OnyxRedisCommand.purge_locks_blocking_deletion: + if cc_pair_id is None: + logger.error("You must specify --cc-pair with purge_deletion_locks") + return 1 + + tenant_id = get_current_tenant_id() + logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.") + redis_connector = RedisConnector(tenant_id, cc_pair_id) + + match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*" + purge_by_match_and_type(match_pattern, "string", batch, dry_run, r) + + redis_delete_if_exists_helper( + f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r + ) + redis_delete_if_exists_helper( + f"{tenant_id}:{redis_connector.permissions.fence_key}", dry_run, r + ) + redis_delete_if_exists_helper( + f"{tenant_id}:{redis_connector.external_group_sync.fence_key}", dry_run, r + ) + return 0 + elif command == OnyxRedisCommand.purge_vespa_syncing: return purge_by_match_and_type( "*connectorsync:vespa_syncing*", "string", batch, dry_run, r ) - elif command == "get_user_token": + elif command == OnyxRedisCommand.get_user_token: if not user_email: logger.error("You must specify --user-email with get_user_token") return 1 @@ -109,7 +155,7 @@ def onyx_redis( else: print(f"No token found for user {user_email}") return 2 - elif command == "delete_user_token": + elif command == OnyxRedisCommand.delete_user_token: if not user_email: logger.error("You must specify --user-email with delete_user_token") return 1 @@ -131,6 +177,25 @@ def flush_batch_delete(batch_keys: list[bytes], r: Redis) -> None: pipe.execute() +def redis_delete_if_exists_helper(key: str, dry_run: bool, r: Redis) -> bool: + """Returns True if the key was found, False if not. + This function exists for logging purposes as the delete operation itself + doesn't really need to check the existence of the key. 
+ """ + + if not r.exists(key): + logger.info(f"Did not find {key}.") + return False + + if dry_run: + logger.info(f"(DRY-RUN) Deleting {key}.") + else: + logger.info(f"Deleting {key}.") + r.delete(key) + + return True + + def purge_by_match_and_type( match_pattern: str, match_type: str, batch_size: int, dry_run: bool, r: Redis ) -> int: @@ -138,6 +203,12 @@ def purge_by_match_and_type( match_type: https://redis.io/docs/latest/commands/type/ """ + logger.info( + f"purge_by_match_and_type start: " + f"match_pattern={match_pattern} " + f"match_type={match_type}" + ) + # cursor = "0" # while cursor != 0: # cursor, data = self.scan( @@ -164,13 +235,15 @@ def purge_by_match_and_type( logger.info(f"Deleting item {count}: {key_str}") batch_keys.append(key) + + # flush if batch size has been reached if len(batch_keys) >= batch_size: flush_batch_delete(batch_keys, r) batch_keys.clear() - if len(batch_keys) >= batch_size: - flush_batch_delete(batch_keys, r) - batch_keys.clear() + # final flush + flush_batch_delete(batch_keys, r) + batch_keys.clear() logger.info(f"Deleted {count} matches.") @@ -279,7 +352,21 @@ def delete_user_token_from_redis( if __name__ == "__main__": parser = argparse.ArgumentParser(description="Onyx Redis Manager") - parser.add_argument("--command", type=str, help="Operation to run", required=True) + parser.add_argument( + "--command", + type=OnyxRedisCommand, + help="The command to run", + choices=list(OnyxRedisCommand), + required=True, + ) + + parser.add_argument( + "--ssl", + type=bool, + default=REDIS_SSL, + help="Use SSL when connecting to Redis. Usually True for prod and False for local testing", + required=False, + ) parser.add_argument( "--host", @@ -342,6 +429,13 @@ if __name__ == "__main__": required=False, ) + parser.add_argument( + "--cc-pair", + type=int, + help="A connector credential pair id. 
Used with the purge_deletion_locks command.", + required=False, + ) + args = parser.parse_args() if args.tenant_id: @@ -368,10 +462,12 @@ if __name__ == "__main__": command=args.command, batch=args.batch, dry_run=args.dry_run, + ssl=args.ssl, host=args.host, port=args.port, db=args.db, password=args.password, user_email=args.user_email, + cc_pair_id=args.cc_pair, ) sys.exit(exitcode) diff --git a/backend/tests/integration/common_utils/constants.py b/backend/tests/integration/common_utils/constants.py index c6731e739..5f0247852 100644 --- a/backend/tests/integration/common_utils/constants.py +++ b/backend/tests/integration/common_utils/constants.py @@ -3,7 +3,7 @@ import os ADMIN_USER_NAME = "admin_user" API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http" -API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost" +API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "127.0.0.1" API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080" API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}" MAX_DELAY = 45 diff --git a/backend/tests/integration/common_utils/managers/connector.py b/backend/tests/integration/common_utils/managers/connector.py index 04dd37c2a..410182a92 100644 --- a/backend/tests/integration/common_utils/managers/connector.py +++ b/backend/tests/integration/common_utils/managers/connector.py @@ -30,8 +30,10 @@ class ConnectorManager: name=name, source=source, input_type=input_type, - connector_specific_config=connector_specific_config - or {"file_locations": []}, + connector_specific_config=( + connector_specific_config + or ({"file_locations": []} if source == DocumentSource.FILE else {}) + ), access_type=access_type, groups=groups or [], ) diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py index 0045e7995..950336f11 100644 --- a/backend/tests/integration/common_utils/managers/user.py +++ b/backend/tests/integration/common_utils/managers/user.py @@ -88,8 +88,6 @@ class UserManager: if not session_cookie: raise Exception("Failed to login") - print(f"Logged in as {test_user.email}") - # Set cookies in the headers test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; " test_user.cookies = {"fastapiusersauth": session_cookie} diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index 175a3d58a..8e2a5f448 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -70,7 +70,7 @@ def _answer_fixture_impl( files=[], single_message_history=None, ), - system_message=default_build_system_message(prompt_config), + system_message=default_build_system_message(prompt_config, mock_llm.config), message_history=[], llm_config=mock_llm.config, raw_user_query=QUERY, diff --git a/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py new file mode 100644 index 000000000..4d6d9a6a3 --- /dev/null +++ b/backend/tests/unit/onyx/utils/test_threadpool_contextvars.py @@ -0,0 +1,131 @@ +import contextvars +import time + +from onyx.utils.threadpool_concurrency import FunctionCall +from onyx.utils.threadpool_concurrency import run_functions_in_parallel +from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel +from onyx.utils.threadpool_concurrency import run_with_timeout + +# Create a test contextvar +test_var = contextvars.ContextVar("test_var", default="default") + + +def 
get_contextvar_value() -> str: + """Helper function that runs in a thread and returns the contextvar value""" + # Add a small sleep to ensure we're actually running in a different thread + time.sleep(0.1) + return test_var.get() + + +def test_run_with_timeout_preserves_contextvar() -> None: + """Test that run_with_timeout preserves contextvar values""" + # Set a value in the main thread + test_var.set("test_value") + + # Run function with timeout and verify the value is preserved + result = run_with_timeout(1.0, get_contextvar_value) + assert result == "test_value" + + +def test_run_functions_in_parallel_preserves_contextvar() -> None: + """Test that run_functions_in_parallel preserves contextvar values""" + # Set a value in the main thread + test_var.set("parallel_test") + + # Create multiple function calls + function_calls = [ + FunctionCall(get_contextvar_value), + FunctionCall(get_contextvar_value), + ] + + # Run in parallel and verify all results have the correct value + results = run_functions_in_parallel(function_calls) + + for result_id, value in results.items(): + assert value == "parallel_test" + + +def test_run_functions_tuples_preserves_contextvar() -> None: + """Test that run_functions_tuples_in_parallel preserves contextvar values""" + # Set a value in the main thread + test_var.set("tuple_test") + + # Create list of function tuples + functions_with_args = [ + (get_contextvar_value, ()), + (get_contextvar_value, ()), + ] + + # Run in parallel and verify all results have the correct value + results = run_functions_tuples_in_parallel(functions_with_args) + + for result in results: + assert result == "tuple_test" + + +def test_nested_contextvar_modifications() -> None: + """Test that modifications to contextvars in threads don't affect other threads""" + + def modify_and_return_contextvar(new_value: str) -> tuple[str, str]: + """Helper that modifies the contextvar and returns both values""" + original = test_var.get() + test_var.set(new_value) + time.sleep(0.1) # Ensure threads overlap + return original, test_var.get() + + # Set initial value + test_var.set("initial") + + # Run multiple functions that modify the contextvar + functions_with_args = [ + (modify_and_return_contextvar, ("thread1",)), + (modify_and_return_contextvar, ("thread2",)), + ] + + results = run_functions_tuples_in_parallel(functions_with_args) + + # Verify each thread saw the initial value and its own modification + for original, modified in results: + assert original == "initial" # Each thread should see the initial value + assert modified in [ + "thread1", + "thread2", + ] # Each thread should see its own modification + + # Verify the main thread's value wasn't affected + assert test_var.get() == "initial" + + +def test_contextvar_isolation_between_runs() -> None: + """Test that contextvar changes don't leak between separate parallel runs""" + + def set_and_return_contextvar(value: str) -> str: + test_var.set(value) + return test_var.get() + + # First run + test_var.set("first_run") + first_results = run_functions_tuples_in_parallel( + [ + (set_and_return_contextvar, ("thread1",)), + (set_and_return_contextvar, ("thread2",)), + ] + ) + + # Verify first run results + assert all(result in ["thread1", "thread2"] for result in first_results) + + # Second run should still see the main thread's value + assert test_var.get() == "first_run" + + # Second run with different value + test_var.set("second_run") + second_results = run_functions_tuples_in_parallel( + [ + (set_and_return_contextvar, ("thread3",)), + 
(set_and_return_contextvar, ("thread4",)), + ] + ) + + # Verify second run results + assert all(result in ["thread3", "thread4"] for result in second_results) diff --git a/deployment/data/nginx/run-nginx.sh b/deployment/data/nginx/run-nginx.sh index 5f18b0d6b..01f9c1497 100755 --- a/deployment/data/nginx/run-nginx.sh +++ b/deployment/data/nginx/run-nginx.sh @@ -1,5 +1,5 @@ # fill in the template -envsubst '$SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf +envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf # wait for the api_server to be ready echo "Waiting for API server to boot up; this may take a minute or two..." diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 14860827c..525246e54 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -36,6 +36,7 @@ services: - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-} - TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-} - CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-} + - INTEGRATION_TESTS_MODE=${INTEGRATION_TESTS_MODE:-} # Gen AI Settings - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-} - QA_TIMEOUT=${QA_TIMEOUT:-} diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index b1892ddff..69cc80088 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -113,7 +113,6 @@ export function AssistantEditor({ documentSets, user, defaultPublic, - redirectType, llmProviders, tools, shouldAddAssistantToUserPreferences, @@ -124,7 +123,6 @@ export function AssistantEditor({ documentSets: DocumentSet[]; user: User | null; defaultPublic: boolean; - redirectType: SuccessfulPersonaUpdateRedirectType; llmProviders: FullLLMProvider[]; tools: ToolSnapshot[]; shouldAddAssistantToUserPreferences?: boolean; @@ -502,7 +500,7 @@ export function AssistantEditor({ ) .map((message: { message: string; name?: string }) => ({ message: message.message, - name: message.name || message.message, + name: message.message, })); // don't set groups if marked as public diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx index f4bfff0d5..503771d98 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm.tsx @@ -1,7 +1,7 @@ "use client"; -import React, { useMemo, useState, useEffect } from "react"; -import { Formik, Form, Field } from "formik"; +import React, { useMemo } from "react"; +import { Formik, Form } from "formik"; import * as Yup from "yup"; import { usePopup } from "@/components/admin/connectors/Popup"; import { @@ -13,17 +13,13 @@ import { createSlackChannelConfig, isPersonaASlackBotPersona, updateSlackChannelConfig, - fetchSlackChannels, } from "../lib"; import CardSection from "@/components/admin/CardSection"; import { useRouter } from "next/navigation"; import { Persona } from "@/app/admin/assistants/interfaces"; import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE"; -import { SEARCH_TOOL_ID, SEARCH_TOOL_NAME } from "@/app/chat/tools/constants"; -import { - SlackChannelConfigFormFields, - SlackChannelConfigFormFieldsProps, -} from 
"./SlackChannelConfigFormFields"; +import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants"; +import { SlackChannelConfigFormFields } from "./SlackChannelConfigFormFields"; export const SlackChannelConfigCreationForm = ({ slack_bot_id, diff --git a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx index 3372756e9..13a80f6bb 100644 --- a/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx +++ b/web/src/app/admin/bots/[bot-id]/channels/SlackChannelConfigFormFields.tsx @@ -1,13 +1,7 @@ "use client"; import React, { useState, useEffect, useMemo } from "react"; -import { - FieldArray, - Form, - useFormikContext, - ErrorMessage, - Field, -} from "formik"; +import { FieldArray, useFormikContext, ErrorMessage, Field } from "formik"; import { CCPairDescriptor, DocumentSet } from "@/lib/types"; import { Label, @@ -18,14 +12,13 @@ import { } from "@/components/admin/connectors/Field"; import { Button } from "@/components/ui/button"; import { Persona } from "@/app/admin/assistants/interfaces"; -import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection"; import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE"; import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown"; import { RadioGroup } from "@/components/ui/radio-group"; import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField"; -import { AlertCircle, View } from "lucide-react"; +import { AlertCircle } from "lucide-react"; import { useRouter } from "next/navigation"; import { Tooltip, @@ -50,6 +43,7 @@ import { import { Separator } from "@/components/ui/separator"; import { CheckFormField } from "@/components/ui/CheckField"; +import { Input } from "@/components/ui/input"; export interface SlackChannelConfigFormFieldsProps { isUpdate: boolean; @@ -178,9 +172,13 @@ export function SlackChannelConfigFormFields({ ); }, [documentSets]); - const { data: channelOptions, isLoading } = useSWR( + const { + data: channelOptions, + error, + isLoading, + } = useSWR( `/api/manage/admin/slack-app/bots/${slack_bot_id}/channels`, - async (url: string) => { + async () => { const channels = await fetchSlackChannels(slack_bot_id); return channels.map((channel: any) => ({ name: channel.name, @@ -227,20 +225,34 @@ export function SlackChannelConfigFormFields({ > Select A Slack Channel: {" "} - - {({ field, form }: { field: any; form: any }) => ( - { - form.setFieldValue("channel_name", selected.name); - }} - initialSearchTerm={field.value} - onSearchTermChange={(term) => { - form.setFieldValue("channel_name", term); - }} + {error ? ( +
+
+ {error.message || "Unable to fetch Slack channels."} + {" Please enter the channel name manually."} +
+ - )} - +
+ ) : ( + + {({ field, form }: { field: any; form: any }) => ( + { + form.setFieldValue("channel_name", selected.name); + }} + initialSearchTerm={field.value} + onSearchTermChange={(term) => { + form.setFieldValue("channel_name", term); + }} + /> + )} + + )} )}
diff --git a/web/src/app/assistants/edit/[id]/page.tsx b/web/src/app/assistants/edit/[id]/page.tsx index 4213ab9fd..2826b3a94 100644 --- a/web/src/app/assistants/edit/[id]/page.tsx +++ b/web/src/app/assistants/edit/[id]/page.tsx @@ -21,11 +21,7 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
- +
diff --git a/web/src/app/assistants/new/page.tsx b/web/src/app/assistants/new/page.tsx index 1f831cb5c..f50b91bae 100644 --- a/web/src/app/assistants/new/page.tsx +++ b/web/src/app/assistants/new/page.tsx @@ -26,7 +26,6 @@ export default async function Page() { diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index a1e5fd2b7..b285592c1 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -47,6 +47,7 @@ import { removeMessage, sendMessage, setMessageAsLatest, + updateLlmOverrideForChatSession, updateParentChildren, uploadFilesForChat, useScrollonStream, @@ -65,7 +66,7 @@ import { import { usePopup } from "@/components/admin/connectors/Popup"; import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams"; import { useDocumentSelection } from "./useDocumentSelection"; -import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks"; +import { LlmDescriptor, useFilters, useLlmManager } from "@/lib/hooks"; import { ChatState, FeedbackType, RegenerationState } from "./types"; import { DocumentResults } from "./documentSidebar/DocumentResults"; import { OnyxInitializingLoader } from "@/components/OnyxInitializingLoader"; @@ -89,7 +90,11 @@ import { import { buildFilters } from "@/lib/search/utils"; import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; -import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils"; +import { + checkLLMSupportsImageInput, + getFinalLLM, + structureValue, +} from "@/lib/llm/utils"; import { ChatInputBar } from "./input/ChatInputBar"; import { useChatContext } from "@/components/context/ChatContext"; import { v4 as uuidv4 } from "uuid"; @@ -194,16 +199,6 @@ export function ChatPage({ return screenSize; } - const { height: screenHeight } = useScreenSize(); - - const getContainerHeight = () => { - if (autoScrollEnabled) return undefined; - - if (screenHeight < 600) return "20vh"; - if (screenHeight < 1200) return "30vh"; - return "40vh"; - }; - // handle redirect if chat page is disabled // NOTE: this must be done here, in a client component since // settings are passed in via Context and therefore aren't @@ -222,6 +217,7 @@ export function ChatPage({ setProSearchEnabled(!proSearchEnabled); }; + const isInitialLoad = useRef(true); const [userSettingsToggled, setUserSettingsToggled] = useState(false); const { @@ -356,7 +352,7 @@ export function ChatPage({ ] ); - const llmOverrideManager = useLlmOverride( + const llmManager = useLlmManager( llmProviders, selectedChatSession, liveAssistant @@ -520,8 +516,17 @@ export function ChatPage({ scrollInitialized.current = false; if (!hasPerformedInitialScroll) { + if (isInitialLoad.current) { + setHasPerformedInitialScroll(true); + isInitialLoad.current = false; + } clientScrollToBottom(); + + setTimeout(() => { + setHasPerformedInitialScroll(true); + }, 100); } else if (isChatSessionSwitch) { + setHasPerformedInitialScroll(true); clientScrollToBottom(true); } @@ -1130,6 +1135,56 @@ export function ChatPage({ }); }; const [uncaughtError, setUncaughtError] = useState(null); + const [agenticGenerating, setAgenticGenerating] = useState(false); + + const autoScrollEnabled = + (user?.preferences?.auto_scroll && !agenticGenerating) ?? 
false; + + useScrollonStream({ + chatState: currentSessionChatState, + scrollableDivRef, + scrollDist, + endDivRef, + debounceNumber, + mobile: settings?.isMobile, + enableAutoScroll: autoScrollEnabled, + }); + + // Track whether a message has been sent during this page load, keyed by chat session id + const [sessionHasSentLocalUserMessage, setSessionHasSentLocalUserMessage] = + useState>(new Map()); + + // Update the local state for a session once the user sends a message + const markSessionMessageSent = (sessionId: string | null) => { + setSessionHasSentLocalUserMessage((prev) => { + const newMap = new Map(prev); + newMap.set(sessionId, true); + return newMap; + }); + }; + const currentSessionHasSentLocalUserMessage = useMemo( + () => (sessionId: string | null) => { + return sessionHasSentLocalUserMessage.size === 0 + ? undefined + : sessionHasSentLocalUserMessage.get(sessionId) || false; + }, + [sessionHasSentLocalUserMessage] + ); + + const { height: screenHeight } = useScreenSize(); + + const getContainerHeight = useMemo(() => { + return () => { + if (!currentSessionHasSentLocalUserMessage(chatSessionIdRef.current)) { + return undefined; + } + if (autoScrollEnabled) return undefined; + + if (screenHeight < 600) return "40vh"; + if (screenHeight < 1200) return "50vh"; + return "60vh"; + }; + }, [autoScrollEnabled, screenHeight, currentSessionHasSentLocalUserMessage]); const onSubmit = async ({ messageIdToResend, @@ -1138,7 +1193,7 @@ export function ChatPage({ forceSearch, isSeededChat, alternativeAssistantOverride = null, - modelOverRide, + modelOverride, regenerationRequest, overrideFileDescriptors, }: { @@ -1148,7 +1203,7 @@ export function ChatPage({ forceSearch?: boolean; isSeededChat?: boolean; alternativeAssistantOverride?: Persona | null; - modelOverRide?: LlmOverride; + modelOverride?: LlmDescriptor; regenerationRequest?: RegenerationRequest | null; overrideFileDescriptors?: FileDescriptor[]; } = {}) => { @@ -1156,6 +1211,9 @@ export function ChatPage({ let frozenSessionId = currentSessionId(); updateCanContinue(false, frozenSessionId); + // Mark that we've sent a message for this session in the current page load + markSessionMessageSent(frozenSessionId); + if (currentChatState() != "input") { if (currentChatState() == "uploading") { setPopup({ @@ -1191,6 +1249,22 @@ export function ChatPage({ currChatSessionId = chatSessionIdRef.current as string; } frozenSessionId = currChatSessionId; + // update the selected model for the chat session if one is specified so that + // it persists across page reloads. Do not `await` here so that the message + // request can continue and this will just happen in the background. + // NOTE: only set the model override for the chat session once we send a + // message with it. If the user switches models and then starts a new + // chat session, it is unexpected for that model to be used when they + // return to this session the next day. + let finalLLM = modelOverride || llmManager.currentLlm; + updateLlmOverrideForChatSession( + currChatSessionId, + structureValue( + finalLLM.name || "", + finalLLM.provider || "", + finalLLM.modelName || "" + ) + ); updateStatesWithNewSessionId(currChatSessionId); @@ -1250,11 +1324,14 @@ export function ChatPage({ : null) || (messageMap.size === 1 ? Array.from(messageMap.values())[0] : null); - const currentAssistantId = alternativeAssistantOverride - ? alternativeAssistantOverride.id - : alternativeAssistant - ? 
@@ -1326,15 +1403,13 @@
           forceSearch,
           regenerate: regenerationRequest !== undefined,
           modelProvider:
-            modelOverRide?.name ||
-            llmOverrideManager.llmOverride.name ||
-            undefined,
+            modelOverride?.name || llmManager.currentLlm.name || undefined,
           modelVersion:
-            modelOverRide?.modelName ||
-            llmOverrideManager.llmOverride.modelName ||
+            modelOverride?.modelName ||
+            llmManager.currentLlm.modelName ||
             searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
             undefined,
-          temperature: llmOverrideManager.temperature || undefined,
+          temperature: llmManager.temperature || undefined,
           systemPromptOverride:
             searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
           useExistingUserMessage: isSeededChat,
@@ -1802,7 +1877,7 @@
     const [_, llmModel] = getFinalLLM(
       llmProviders,
       liveAssistant,
-      llmOverrideManager.llmOverride
+      llmManager.currentLlm
     );
     const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
 
@@ -1857,7 +1932,6 @@
   // Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
   const [untoggled, setUntoggled] = useState(false);
   const [loadingError, setLoadingError] = useState<string | null>(null);
-  const [agenticGenerating, setAgenticGenerating] = useState(false);
 
   const explicitlyUntoggle = () => {
     setShowHistorySidebar(false);
@@ -1899,19 +1973,6 @@
     isAnonymousUser: user?.is_anonymous_user,
   });
 
-  const autoScrollEnabled =
-    (user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
-
-  useScrollonStream({
-    chatState: currentSessionChatState,
-    scrollableDivRef,
-    scrollDist,
-    endDivRef,
-    debounceNumber,
-    mobile: settings?.isMobile,
-    enableAutoScroll: autoScrollEnabled,
-  });
-
   // Virtualization + Scrolling related effects and functions
   const scrollInitialized = useRef(false);
   interface VisibleRange {
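The `modelProvider`/`modelVersion` hunk above encodes a precedence order: an explicit per-call `modelOverride` wins, then the manager's current LLM, then (for the version) the URL parameter, and finally `undefined` so the server applies its own default. The same order written out as a plain function, with types simplified for the sketch:

```typescript
interface LlmDescriptorLike {
  name?: string;
  modelName?: string;
}

// Resolution order for modelVersion in the payload above: per-call override,
// then the manager's current LLM, then the URL param, else undefined.
function resolveModelVersion(
  modelOverride: LlmDescriptorLike | undefined,
  currentLlm: LlmDescriptorLike,
  urlParam: string | null
): string | undefined {
  return (
    modelOverride?.modelName || currentLlm.modelName || urlParam || undefined
  );
}
```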
@@ -2121,7 +2182,7 @@
   }, [searchParams, router]);
 
   useEffect(() => {
-    llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
+    llmManager.updateImageFilesPresent(imageFileInMessageHistory);
   }, [imageFileInMessageHistory]);
 
   const pathname = usePathname();
@@ -2175,9 +2236,9 @@
 
   function createRegenerator(regenerationRequest: RegenerationRequest) {
     // Returns new function that only needs `modelOverRide` to be specified when called
-    return async function (modelOverRide: LlmOverride) {
+    return async function (modelOverride: LlmDescriptor) {
       return await onSubmit({
-        modelOverRide,
+        modelOverride,
         messageIdToResend: regenerationRequest.parentMessage.messageId,
         regenerationRequest,
         forceSearch: regenerationRequest.forceSearch,
@@ -2258,9 +2319,7 @@
           {(settingsToggled || userSettingsToggled) && (
             <UserSettingsModal
               setPopup={setPopup}
-              setLlmOverride={(newOverride) =>
-                llmOverrideManager.updateLLMOverride(newOverride)
-              }
+              setCurrentLlm={(newLlm) => llmManager.updateCurrentLlm(newLlm)}
               defaultModel={user?.preferences.default_model!}
               llmProviders={llmProviders}
               onClose={() => {
@@ -2324,7 +2383,7 @@
             onClose={() => setSharedChatSession(null)}
@@ -2342,7 +2401,7 @@
             onClose={() => setSharingModalVisible(false)}
@@ -2572,6 +2631,7 @@
             style={{ overflowAnchor: "none" }}
             key={currentSessionId()}
             className={
+              (hasPerformedInitialScroll ? "" : " hidden ") +
               "desktop:-ml-4 w-full mx-auto " +
               "absolute mobile:top-0 desktop:top-0 left-0 " +
               (settings?.enterpriseSettings
@@ -3058,7 +3118,7 @@
                             messageId: message.messageId,
                             parentMessage: parentMessage!,
                             forceSearch: true,
-                          })(llmOverrideManager.llmOverride);
+                          })(llmManager.currentLlm);
                         } else {
                           setPopup({
                             type: "error",
@@ -3203,7 +3263,7 @@
                   availableDocumentSets={documentSets}
                   availableTags={tags}
                   filterManager={filterManager}
-                  llmOverrideManager={llmOverrideManager}
+                  llmManager={llmManager}
                   removeDocs={() => {
                     clearSelectedDocuments();
                   }}
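`createRegenerator` (the `@@ -2175` hunk above) is a small currying trick: the regeneration context is captured once, and the returned function needs only an `LlmDescriptor`, which is what both the regenerate dropdown and the force-search path (`@@ -3058`) supply. A trimmed-down sketch of the shape, with types reduced to what the pattern needs:

```typescript
interface LlmDescriptor {
  name: string;
  provider: string;
  modelName: string;
}

interface RegenerationRequest {
  parentMessageId: number;
  forceSearch?: boolean;
}

type SubmitFn = (args: {
  modelOverride: LlmDescriptor;
  messageIdToResend: number;
  forceSearch?: boolean;
}) => Promise<void>;

// Bind the regeneration context once; leave only the model to be chosen
// later, e.g. from a model-picker dropdown.
function createRegenerator(request: RegenerationRequest, onSubmit: SubmitFn) {
  return async function regenerate(modelOverride: LlmDescriptor) {
    return onSubmit({
      modelOverride,
      messageIdToResend: request.parentMessageId,
      forceSearch: request.forceSearch,
    });
  };
}
```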
diff --git a/web/src/app/chat/RegenerateOption.tsx b/web/src/app/chat/RegenerateOption.tsx
index f947ebf78..1265db2eb 100644
--- a/web/src/app/chat/RegenerateOption.tsx
+++ b/web/src/app/chat/RegenerateOption.tsx
@@ -1,8 +1,8 @@
 import { useChatContext } from "@/components/context/ChatContext";
 import {
   getDisplayNameForModel,
-  LlmOverride,
-  useLlmOverride,
+  LlmDescriptor,
+  useLlmManager,
 } from "@/lib/hooks";
 
 import { StringOrNumberOption } from "@/components/Dropdown";
@@ -106,13 +106,13 @@ export default function RegenerateOption({
   onDropdownVisibleChange,
 }: {
   selectedAssistant: Persona;
-  regenerate: (modelOverRide: LlmOverride) => Promise<void>;
+  regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
   overriddenModel?: string;
   onHoverChange: (isHovered: boolean) => void;
   onDropdownVisibleChange: (isVisible: boolean) => void;
 }) {
   const { llmProviders } = useChatContext();
-  const llmOverrideManager = useLlmOverride(llmProviders);
+  const llmManager = useLlmManager(llmProviders);
 
   const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
 
@@ -148,7 +148,7 @@
   );
 
   const currentModelName =
-    llmOverrideManager?.llmOverride.modelName ||
+    llmManager?.currentLlm.modelName ||
     (selectedAssistant
       ? selectedAssistant.llm_model_version_override || llmName
       : llmName);
diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx
index 4010be443..f8a5c34aa 100644
--- a/web/src/app/chat/input/ChatInputBar.tsx
+++ b/web/src/app/chat/input/ChatInputBar.tsx
@@ -6,7 +6,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
 
 import LLMPopover from "./LLMPopover";
 import { InputPrompt } from "@/app/chat/interfaces";
-import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
+import { FilterManager, LlmManager } from "@/lib/hooks";
 import { useChatContext } from "@/components/context/ChatContext";
 import { ChatFileType, FileDescriptor } from "../interfaces";
 import {
@@ -180,7 +180,7 @@ interface ChatInputBarProps {
   setMessage: (message: string) => void;
   stopGenerating: () => void;
   onSubmit: () => void;
-  llmOverrideManager: LlmOverrideManager;
+  llmManager: LlmManager;
   chatState: ChatState;
   alternativeAssistant: Persona | null;
   // assistants
@@ -225,7 +225,7 @@ export function ChatInputBar({
   availableSources,
   availableDocumentSets,
   availableTags,
-  llmOverrideManager,
+  llmManager,
   proSearchEnabled,
   setProSearchEnabled,
 }: ChatInputBarProps) {
@@ -781,7 +781,7 @@ export function ChatInputBar({
diff --git a/web/src/app/chat/input/LLMPopover.tsx b/web/src/app/chat/input/LLMPopover.tsx
index ad7e18e8e..1a4b6ab08 100644
--- a/web/src/app/chat/input/LLMPopover.tsx
+++ b/web/src/app/chat/input/LLMPopover.tsx
@@ -16,7 +16,7 @@ import {
   LLMProviderDescriptor,
 } from "@/app/admin/configuration/llm/interfaces";
 import { Persona } from "@/app/admin/assistants/interfaces";
-import { LlmOverrideManager } from "@/lib/hooks";
+import { LlmManager } from "@/lib/hooks";
 
 import {
   Tooltip,
@@ -31,21 +31,19 @@ import { useUser } from "@/components/user/UserProvider";
 
 interface LLMPopoverProps {
   llmProviders: LLMProviderDescriptor[];
-  llmOverrideManager: LlmOverrideManager;
+  llmManager: LlmManager;
   requiresImageGeneration?: boolean;
   currentAssistant?: Persona;
 }
 
 export default function LLMPopover({
   llmProviders,
-  llmOverrideManager,
+  llmManager,
   requiresImageGeneration,
   currentAssistant,
 }: LLMPopoverProps) {
   const [isOpen, setIsOpen] = useState(false);
   const { user } = useUser();
-  const { llmOverride, updateLLMOverride } = llmOverrideManager;
-  const currentLlm = llmOverride.modelName;
 
   const llmOptionsByProvider: {
     [provider: string]: {
@@ -93,19 +91,19 @@
       : null;
 
   const [localTemperature, setLocalTemperature] = useState(
-    llmOverrideManager.temperature ?? 0.5
+    llmManager.temperature ?? 0.5
   );
 
   useEffect(() => {
-    setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
-  }, [llmOverrideManager.temperature]);
+    setLocalTemperature(llmManager.temperature ?? 0.5);
+  }, [llmManager.temperature]);
 
   const handleTemperatureChange = (value: number[]) => {
     setLocalTemperature(value[0]);
   };
 
   const handleTemperatureChangeComplete = (value: number[]) => {
-    llmOverrideManager.updateTemperature(value[0]);
+    llmManager.updateTemperature(value[0]);
   };
 
   return (
@@ -120,15 +118,15 @@
             toggle
             flexPriority="stiff"
             name={getDisplayNameForModel(
-              llmOverrideManager?.llmOverride.modelName ||
+              llmManager?.currentLlm.modelName ||
                 defaultModelDisplayName ||
                 "Models"
             )}
             Icon={getProviderIcon(
-              llmOverrideManager?.llmOverride.provider ||
+              llmManager?.currentLlm.provider ||
                 defaultProvider?.provider ||
                 "anthropic",
-              llmOverrideManager?.llmOverride.modelName ||
+              llmManager?.currentLlm.modelName ||
                 defaultProvider?.default_model_name ||
                 "claude-3-5-sonnet-20240620"
             )}
@@ -147,12 +145,12 @@
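The LLMPopover hunks keep a `localTemperature` that updates on every slider movement and only push the value into the manager when the drag completes, so upstream consumers re-render once per adjustment rather than once per tick. The pattern, extracted into a hypothetical standalone hook (the `LlmManagerLike` interface is a simplification for the sketch):

```typescript
import { useEffect, useState } from "react";

interface LlmManagerLike {
  temperature: number | null;
  updateTemperature: (t: number) => void;
}

// Hypothetical hook mirroring the local-then-commit slider handling above.
export function useLocalTemperature(llmManager: LlmManagerLike) {
  const [localTemperature, setLocalTemperature] = useState(
    llmManager.temperature ?? 0.5
  );

  // Re-sync the local copy if the committed value changes elsewhere.
  useEffect(() => {
    setLocalTemperature(llmManager.temperature ?? 0.5);
  }, [llmManager.temperature]);

  // Called on every slider tick: cheap local update only.
  const onDrag = (value: number[]) => setLocalTemperature(value[0]);

  // Called when the drag ends: commit to the shared manager.
  const onDragComplete = (value: number[]) =>
    llmManager.updateTemperature(value[0]);

  return { localTemperature, onDrag, onDragComplete };
}
```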