Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/schema-translate-map
commit 23bdff6e21

.github/workflows/pr-integration-tests.yml (vendored, 5 changes)
@@ -145,7 +145,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v

# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |

@@ -157,6 +157,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker

@@ -199,7 +200,7 @@ jobs:
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d

# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
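A side note on the "pre-ping/null pool" comments in the hunks above: they refer to SQLAlchemy connection-pool settings that tolerate dropped Postgres connections in CI. A minimal sketch of the two options, assuming SQLAlchemy and the Postgres defaults from this repo's configs (illustrative only, not the repository's actual engine setup):

from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

# pool_pre_ping issues a lightweight liveness check before each connection
# checkout, silently replacing connections the server has dropped.
pre_ping_engine = create_engine(
    "postgresql://postgres:password@127.0.0.1:5432/postgres",
    pool_pre_ping=True,
)

# NullPool skips pooling entirely: every checkout opens a fresh connection,
# trading throughput for immunity to stale pooled connections.
null_pool_engine = create_engine(
    "postgresql://postgres:password@127.0.0.1:5432/postgres",
    poolclass=NullPool,
)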
.github/workflows/pr-python-model-tests.yml (vendored, 77 changes)
@@ -1,10 +1,16 @@
name: Connector Tests
name: Model Server Tests

on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"

workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'

env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}

@@ -26,6 +32,23 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4

- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

# tag every docker image with "test" so that we can spin up the correct set
# of images during testing

# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test

- name: Set up Python
uses: actions/setup-python@v5
with:

@@ -41,6 +64,49 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt

- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d indexing_model_server
id: start_docker

- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."

start_time=$(date +%s)
timeout=300 # 5 minutes in seconds

while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))

if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi

# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")

if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi

sleep 5
done
echo "Finished waiting for service."

- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |

@@ -56,3 +122,10 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
README.md (23 changes)
@@ -26,12 +26,12 @@

<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.

<h3>Feature Showcase</h3>
<h3>Feature Highlights</h3>

**Deep research over your team's knowledge:**

@@ -63,22 +63,21 @@ We also have built-in support for high-availability/scalable deployment on Kuber
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).

## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.

## 🚧 Roadmap
- Extensions to the Chrome Plugin
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language

## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models only through Onyx + learn from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.

## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit

@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
def _check_token_rate_limits(user: User | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()

elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()

else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
]
)

@@ -52,8 +52,8 @@ User rate limits
"""


def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)

@@ -93,8 +93,8 @@ User Group rate limits
"""


def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

if group_rate_limits:
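The pattern this commit applies everywhere: instead of threading tenant_id through every call, the tenant is read from a contextvar at the moment a session is needed. A minimal runnable sketch of the idea; CURRENT_TENANT_ID and FakeSession below are illustrative stand-ins, not the real helpers in onyx.db.engine and shared_configs.contextvars:

import contextvars
from collections.abc import Iterator
from contextlib import contextmanager

# Illustrative stand-in for shared_configs.contextvars.CURRENT_TENANT_ID_CONTEXTVAR
CURRENT_TENANT_ID: contextvars.ContextVar[str] = contextvars.ContextVar(
    "current_tenant_id", default="public"
)

class FakeSession:
    """Stub standing in for a SQLAlchemy Session bound to a tenant schema."""

    def __init__(self, tenant_id: str) -> None:
        self.tenant_id = tenant_id

    def close(self) -> None:
        pass

@contextmanager
def get_session_with_current_tenant() -> Iterator[FakeSession]:
    # The tenant comes from the ambient context, so callers such as
    # _user_is_rate_limited() no longer need an explicit tenant_id parameter.
    session = FakeSession(CURRENT_TENANT_ID.get())
    try:
        yield session
    finally:
        session.close()

CURRENT_TENANT_ID.set("tenant_abc")
with get_session_with_current_tenant() as db_session:
    print(db_session.tenant_id)  # -> tenant_abc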
@@ -98,12 +98,17 @@ class CloudEmbedding:
return final_embeddings
except Exception as e:
error_string = (
f"Error embedding text with OpenAI: {str(e)} \n"
f"Model: {model} \n"
f"Provider: {self.provider} \n"
f"Texts: {texts}"
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)

# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")

raise RuntimeError(error_string)

async def _embed_cohere(
@@ -361,6 +361,7 @@ def connector_external_group_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
@@ -15,6 +15,7 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource

@@ -89,8 +90,8 @@ def _get_connector_runner(
)

# validate the connector settings
runnable_connector.validate_connector_settings()
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()

except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -747,14 +747,13 @@ def stream_chat_message_objects(
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config),
system_message=default_build_system_message(prompt_config, llm.config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))

# LLM prompt building, response capturing, etc.
answer = Answer(

@@ -870,7 +869,6 @@
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
info.ai_message_files.extend(
[
@@ -12,6 +12,7 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens

@@ -19,6 +20,7 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness

@@ -31,8 +33,16 @@ from onyx.tools.tool import Tool

def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,

@@ -110,21 +120,8 @@ class AnswerPromptBuilder:
),
)

self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)
self.update_system_prompt(system_message)
self.update_user_prompt(user_message)

self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
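For context on the o-series branch above: when the provider is OpenAI and the model name starts with "o", the system prompt is prefixed with CODE_BLOCK_MARKDOWN ("Formatting re-enabled. ") to nudge o-series models back into emitting markdown. A standalone illustration of the condition, assuming OPENAI_PROVIDER_NAME equals "openai" (check onyx.llm.llm_provider_options for the real value):

CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "  # value from onyx.prompts.chat_prompts

def prefix_for_o_series(model_provider: str, model_name: str, system_prompt: str) -> str:
    # Mirrors the condition in default_build_system_message above.
    if model_provider == "openai" and model_name.startswith("o"):
        return CODE_BLOCK_MARKDOWN + system_prompt
    return system_prompt

assert prefix_for_o_series("openai", "o1", "Answer concisely.").startswith(
    "Formatting re-enabled. "
)
assert prefix_for_o_series("openai", "gpt-4o", "Answer concisely.") == "Answer concisely."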
@@ -158,7 +158,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"

@@ -626,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")

DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")

TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
@@ -3,6 +3,7 @@ from typing import Type

from sqlalchemy.orm import Session

from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector

@@ -187,6 +188,9 @@ def validate_ccpair_for_user(
user: User | None,
tenant_id: str | None,
) -> None:
if INTEGRATION_TESTS_MODE:
return

# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id_for_user(

@@ -199,7 +203,10 @@ def validate_ccpair_for_user(
if not connector:
raise ValueError("Connector not found")

if connector.source == DocumentSource.INGESTION_API:
if (
connector.source == DocumentSource.INGESTION_API
or connector.source == DocumentSource.MOCK_CONNECTOR
):
return

if not credential:
@@ -220,7 +220,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return self._creds

def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)

self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
@@ -194,9 +194,14 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)

if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))

result = db_session.execute(stmt)
return result.scalar_one_or_none()
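The new eager_load_credential flag lets callers (such as the external group sync task earlier in this diff) touch cc_pair.credential after the query without a lazy load, which would fail with a DetachedInstanceError once the session is gone. A sketch of the intended call pattern, assuming an open SQLAlchemy Session named db_session and an illustrative id:

# joinedload fetches the credential via a JOIN in the same query, so the
# relationship is already populated when the session later goes away.
cc_pair = get_connector_credential_pair_from_id(
    db_session=db_session,
    cc_pair_id=42,  # illustrative id
    eager_load_credential=True,
)
if cc_pair is not None:
    print(cc_pair.credential)  # no lazy load needed at access time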
@@ -8,7 +8,7 @@ import requests
from sqlalchemy.orm import Session

from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import FileDescriptor

@@ -53,11 +53,11 @@ def load_all_chat_files(
return files


def save_file_from_url(url: str, tenant_id: str) -> str:
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
response = requests.get(url)
response.raise_for_status()

@@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
return unique_id


def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def save_file_from_base64(base64_string: str) -> str:
with get_session_with_current_tenant() as db_session:
unique_id = str(uuid4())
file_store = get_default_file_store(db_session)
file_store.save_file(

@@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str:


def save_file(
tenant_id: str,
url: str | None = None,
base64_data: str | None = None,
) -> str:
"""Save a file from either a URL or base64 encoded string.

Args:
tenant_id: The tenant ID to save the file under
url: URL to download file from
base64_data: Base64 encoded file data

@@ -111,22 +109,22 @@ def save_file(
raise ValueError("Cannot specify both url and base64_data")

if url is not None:
return save_file_from_url(url, tenant_id)
return save_file_from_url(url)
elif base64_data is not None:
return save_file_from_base64(base64_data, tenant_id)
return save_file_from_base64(base64_data)
else:
raise ValueError("Must specify either url or base64_data")


def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
# NOTE: be explicit about typing so that if we change things, we get notified
funcs: list[
tuple[
Callable[[str, str | None, str | None], str],
tuple[str, str | None, str | None],
Callable[[str | None, str | None], str],
tuple[str | None, str | None],
]
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
] = [(save_file, (url, None)) for url in urls] + [
(save_file, (None, base64_file)) for base64_file in base64_files
]

return run_functions_tuples_in_parallel(funcs)
@@ -18,6 +18,7 @@ Remember to provide inline citations in the format [1], [2], [3], etc.

ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}."

CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

CHAT_USER_PROMPT = f"""
Refer to the following context documents when responding to me.{{optional_ignore_statement}}
@@ -1,5 +1,3 @@
from typing import Any

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

@@ -345,6 +343,9 @@ def list_bot_configs(
]


MAX_CHANNELS = 200


@router.get(
"/admin/slack-app/bots/{bot_id}/channels",
)

@@ -353,38 +354,40 @@ def get_all_channels_from_slack_api(
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches all channels from the Slack API.
If the workspace has 200 or more channels, we raise an error.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
if not tokens or "bot_token" not in tokens:
raise HTTPException(
status_code=404, detail="Bot token not found for the given bot ID"
)

bot_token = tokens["bot_token"]
client = WebClient(token=bot_token)
client = WebClient(token=tokens["bot_token"])

try:
channels = []
cursor = None
while True:
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=1000,
cursor=cursor,
)
for channel in response["channels"]:
channels.append(SlackChannel(id=channel["id"], name=channel["name"]))
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=MAX_CHANNELS,
)

response_metadata: dict[str, Any] = response.get("response_metadata", {})
if isinstance(response_metadata, dict):
cursor = response_metadata.get("next_cursor")
if not cursor:
break
else:
break
channels = [
SlackChannel(id=channel["id"], name=channel["name"])
for channel in response["channels"]
]

if len(channels) == MAX_CHANNELS:
raise HTTPException(
status_code=400,
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
)

return channels

except SlackApiError as e:
raise HTTPException(
status_code=500, detail=f"Error fetching channels from Slack API: {str(e)}"
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",
)
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session

from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit

@@ -21,7 +20,6 @@ from onyx.db.models import User
from onyx.db.token_limit import fetch_all_global_token_rate_limits
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.contextvars import get_current_tenant_id


logger = setup_logger()

@@ -39,13 +37,13 @@ def check_token_rate_limits(
return

versioned_rate_limit_strategy = fetch_versioned_implementation(
"onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
"onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__
)
return versioned_rate_limit_strategy(user, get_current_tenant_id())
return versioned_rate_limit_strategy(user)


def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
_user_is_rate_limited_by_global(tenant_id)
def _check_token_rate_limits(_: User | None) -> None:
_user_is_rate_limited_by_global()


"""

@@ -53,8 +51,8 @@ Global rate limits
"""


def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_global() -> None:
with get_session_context_manager() as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -1,3 +1,4 @@
import contextvars
import threading
import uuid
from collections.abc import Callable

@@ -14,10 +15,6 @@ logger = setup_logger()
R = TypeVar("R")


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
allow_failures: bool = False,

@@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel(

results = []
with ThreadPoolExecutor(max_workers=workers) as executor:
# The primary reason for propagating contextvars is to allow acquiring a db session
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
# find that it is increasing latency we can make using it optional.
future_to_index = {
executor.submit(func, *args): i
executor.submit(contextvars.copy_context().run, func, *args): i
for i, (func, args) in enumerate(functions_with_args)
}

@@ -83,10 +83,6 @@ class FunctionCall(Generic[R]):
return self.func(*self.args, **self.kwargs)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,

@@ -102,7 +98,9 @@ def run_functions_in_parallel(

with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id
executor.submit(
contextvars.copy_context().run, func_call.execute
): func_call.result_id
for func_call in function_calls
}

@@ -143,10 +141,6 @@ class TimeoutThread(threading.Thread):
)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:

@@ -154,7 +148,8 @@ def run_with_timeout(
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
"""
task = TimeoutThread(timeout, func, *args, **kwargs)
context = contextvars.copy_context()
task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
task.start()
task.join(timeout)
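The copy_context().run pattern above snapshots the submitting thread's contextvars (including the tenant id) and runs the function inside that snapshot on the worker thread; a plain executor.submit would run with the worker's own empty context. A self-contained demonstration of the difference, independent of any Onyx code:

import contextvars
from concurrent.futures import ThreadPoolExecutor

tenant: contextvars.ContextVar[str] = contextvars.ContextVar("tenant", default="unset")

def whoami() -> str:
    return tenant.get()

tenant.set("tenant_a")
with ThreadPoolExecutor(max_workers=1) as executor:
    # Worker threads start with their own context: the value set above is invisible.
    bare = executor.submit(whoami).result()
    # Running inside a copied context carries the caller's values across threads.
    wrapped = executor.submit(contextvars.copy_context().run, whoami).result()

print(bare, wrapped)  # -> unset tenant_a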
@@ -3,6 +3,7 @@ import json
import logging
import sys
import time
from enum import Enum
from logging import getLogger
from typing import cast
from uuid import UUID

@@ -20,10 +21,13 @@ from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

# Tool to run helpful operations on Redis in production
# This is targeted for internal usage and may not have all the necessary parameters

@@ -42,6 +46,19 @@ SCAN_ITER_COUNT = 10000
BATCH_DEFAULT = 1000


class OnyxRedisCommand(Enum):
purge_connectorsync_taskset = "purge_connectorsync_taskset"
purge_documentset_taskset = "purge_documentset_taskset"
purge_usergroup_taskset = "purge_usergroup_taskset"
purge_locks_blocking_deletion = "purge_locks_blocking_deletion"
purge_vespa_syncing = "purge_vespa_syncing"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"

def __str__(self) -> str:
return self.value


def get_user_id(user_email: str) -> tuple[UUID, str]:
tenant_id = (
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA

@@ -55,50 +72,79 @@ def get_user_id(user_email: str) -> tuple[UUID, str]:


def onyx_redis(
command: str,
command: OnyxRedisCommand,
batch: int,
dry_run: bool,
ssl: bool,
host: str,
port: int,
db: int,
password: str | None,
user_email: str | None = None,
cc_pair_id: int | None = None,
) -> int:
# this is global and not tenant aware
pool = RedisPool.create_pool(
host=host,
port=port,
db=db,
password=password if password else "",
ssl=REDIS_SSL,
ssl=ssl,
ssl_cert_reqs="optional",
ssl_ca_certs=None,
)

r = Redis(connection_pool=pool)

logger.info("Redis ping starting. This may hang if your settings are incorrect.")

try:
r.ping()
except:
logger.exception("Redis ping exceptioned")
raise

if command == "purge_connectorsync_taskset":
logger.info("Redis ping succeeded.")

if command == OnyxRedisCommand.purge_connectorsync_taskset:
"""Purge connector tasksets. Used when the tasks represented in the tasksets
have been purged."""
return purge_by_match_and_type(
"*connectorsync_taskset*", "set", batch, dry_run, r
)
elif command == "purge_documentset_taskset":
elif command == OnyxRedisCommand.purge_documentset_taskset:
return purge_by_match_and_type(
"*documentset_taskset*", "set", batch, dry_run, r
)
elif command == "purge_usergroup_taskset":
elif command == OnyxRedisCommand.purge_usergroup_taskset:
return purge_by_match_and_type("*usergroup_taskset*", "set", batch, dry_run, r)
elif command == "purge_vespa_syncing":
elif command == OnyxRedisCommand.purge_locks_blocking_deletion:
if cc_pair_id is None:
logger.error("You must specify --cc-pair with purge_deletion_locks")
return 1

tenant_id = get_current_tenant_id()
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
redis_connector = RedisConnector(tenant_id, cc_pair_id)

match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)

redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.permissions.fence_key}", dry_run, r
)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.external_group_sync.fence_key}", dry_run, r
)
return 0
elif command == OnyxRedisCommand.purge_vespa_syncing:
return purge_by_match_and_type(
"*connectorsync:vespa_syncing*", "string", batch, dry_run, r
)
elif command == "get_user_token":
elif command == OnyxRedisCommand.get_user_token:
if not user_email:
logger.error("You must specify --user-email with get_user_token")
return 1

@@ -109,7 +155,7 @@ def onyx_redis(
else:
print(f"No token found for user {user_email}")
return 2
elif command == "delete_user_token":
elif command == OnyxRedisCommand.delete_user_token:
if not user_email:
logger.error("You must specify --user-email with delete_user_token")
return 1

@@ -131,6 +177,25 @@ def flush_batch_delete(batch_keys: list[bytes], r: Redis) -> None:
pipe.execute()


def redis_delete_if_exists_helper(key: str, dry_run: bool, r: Redis) -> bool:
"""Returns True if the key was found, False if not.
This function exists for logging purposes as the delete operation itself
doesn't really need to check the existence of the key.
"""

if not r.exists(key):
logger.info(f"Did not find {key}.")
return False

if dry_run:
logger.info(f"(DRY-RUN) Deleting {key}.")
else:
logger.info(f"Deleting {key}.")
r.delete(key)

return True


def purge_by_match_and_type(
match_pattern: str, match_type: str, batch_size: int, dry_run: bool, r: Redis
) -> int:

@@ -138,6 +203,12 @@ def purge_by_match_and_type(
match_type: https://redis.io/docs/latest/commands/type/
"""

logger.info(
f"purge_by_match_and_type start: "
f"match_pattern={match_pattern} "
f"match_type={match_type}"
)

# cursor = "0"
# while cursor != 0:
# cursor, data = self.scan(

@@ -164,13 +235,15 @@ def purge_by_match_and_type(
logger.info(f"Deleting item {count}: {key_str}")

batch_keys.append(key)

# flush if batch size has been reached
if len(batch_keys) >= batch_size:
flush_batch_delete(batch_keys, r)
batch_keys.clear()

if len(batch_keys) >= batch_size:
flush_batch_delete(batch_keys, r)
batch_keys.clear()
# final flush
flush_batch_delete(batch_keys, r)
batch_keys.clear()

logger.info(f"Deleted {count} matches.")

@@ -279,7 +352,21 @@ def delete_user_token_from_redis(

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Onyx Redis Manager")
parser.add_argument("--command", type=str, help="Operation to run", required=True)
parser.add_argument(
"--command",
type=OnyxRedisCommand,
help="The command to run",
choices=list(OnyxRedisCommand),
required=True,
)

parser.add_argument(
"--ssl",
type=bool,
default=REDIS_SSL,
help="Use SSL when connecting to Redis. Usually True for prod and False for local testing",
required=False,
)

parser.add_argument(
"--host",

@@ -342,6 +429,13 @@ if __name__ == "__main__":
required=False,
)

parser.add_argument(
"--cc-pair",
type=int,
help="A connector credential pair id. Used with the purge_deletion_locks command.",
required=False,
)

args = parser.parse_args()

if args.tenant_id:

@@ -368,10 +462,12 @@ if __name__ == "__main__":
command=args.command,
batch=args.batch,
dry_run=args.dry_run,
ssl=args.ssl,
host=args.host,
port=args.port,
db=args.db,
password=args.password,
user_email=args.user_email,
cc_pair_id=args.cc_pair,
)
sys.exit(exitcode)
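One detail worth noting in the argparse change above: passing the Enum class as type= works because Enum lookup-by-value converts the CLI string to a member, and the __str__ override keeps the choices listing readable. A runnable miniature of the same setup (Command is an illustrative stand-in for OnyxRedisCommand):

import argparse
from enum import Enum

class Command(Enum):
    get_user_token = "get_user_token"
    delete_user_token = "delete_user_token"

    def __str__(self) -> str:
        return self.value

parser = argparse.ArgumentParser()
# Command("...") performs lookup by value, so argparse converts and
# validates the flag in one step.
parser.add_argument("--command", type=Command, choices=list(Command), required=True)

args = parser.parse_args(["--command", "get_user_token"])
assert args.command is Command.get_user_token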
@@ -3,7 +3,7 @@ import os
ADMIN_USER_NAME = "admin_user"

API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "127.0.0.1"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 45
@@ -30,8 +30,10 @@ class ConnectorManager:
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config
or {"file_locations": []},
connector_specific_config=(
connector_specific_config
or ({"file_locations": []} if source == DocumentSource.FILE else {})
),
access_type=access_type,
groups=groups or [],
)
@@ -88,8 +88,6 @@ class UserManager:
if not session_cookie:
raise Exception("Failed to login")

print(f"Logged in as {test_user.email}")

# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}
@@ -70,7 +70,7 @@ def _answer_fixture_impl(
files=[],
single_message_history=None,
),
system_message=default_build_system_message(prompt_config),
system_message=default_build_system_message(prompt_config, mock_llm.config),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,
backend/tests/unit/onyx/utils/test_threadpool_contextvars.py (new file, 131 lines)
@@ -0,0 +1,131 @@
import contextvars
import time

from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_with_timeout

# Create a test contextvar
test_var = contextvars.ContextVar("test_var", default="default")


def get_contextvar_value() -> str:
    """Helper function that runs in a thread and returns the contextvar value"""
    # Add a small sleep to ensure we're actually running in a different thread
    time.sleep(0.1)
    return test_var.get()


def test_run_with_timeout_preserves_contextvar() -> None:
    """Test that run_with_timeout preserves contextvar values"""
    # Set a value in the main thread
    test_var.set("test_value")

    # Run function with timeout and verify the value is preserved
    result = run_with_timeout(1.0, get_contextvar_value)
    assert result == "test_value"


def test_run_functions_in_parallel_preserves_contextvar() -> None:
    """Test that run_functions_in_parallel preserves contextvar values"""
    # Set a value in the main thread
    test_var.set("parallel_test")

    # Create multiple function calls
    function_calls = [
        FunctionCall(get_contextvar_value),
        FunctionCall(get_contextvar_value),
    ]

    # Run in parallel and verify all results have the correct value
    results = run_functions_in_parallel(function_calls)

    for result_id, value in results.items():
        assert value == "parallel_test"


def test_run_functions_tuples_preserves_contextvar() -> None:
    """Test that run_functions_tuples_in_parallel preserves contextvar values"""
    # Set a value in the main thread
    test_var.set("tuple_test")

    # Create list of function tuples
    functions_with_args = [
        (get_contextvar_value, ()),
        (get_contextvar_value, ()),
    ]

    # Run in parallel and verify all results have the correct value
    results = run_functions_tuples_in_parallel(functions_with_args)

    for result in results:
        assert result == "tuple_test"


def test_nested_contextvar_modifications() -> None:
    """Test that modifications to contextvars in threads don't affect other threads"""

    def modify_and_return_contextvar(new_value: str) -> tuple[str, str]:
        """Helper that modifies the contextvar and returns both values"""
        original = test_var.get()
        test_var.set(new_value)
        time.sleep(0.1)  # Ensure threads overlap
        return original, test_var.get()

    # Set initial value
    test_var.set("initial")

    # Run multiple functions that modify the contextvar
    functions_with_args = [
        (modify_and_return_contextvar, ("thread1",)),
        (modify_and_return_contextvar, ("thread2",)),
    ]

    results = run_functions_tuples_in_parallel(functions_with_args)

    # Verify each thread saw the initial value and its own modification
    for original, modified in results:
        assert original == "initial"  # Each thread should see the initial value
        assert modified in [
            "thread1",
            "thread2",
        ]  # Each thread should see its own modification

    # Verify the main thread's value wasn't affected
    assert test_var.get() == "initial"


def test_contextvar_isolation_between_runs() -> None:
    """Test that contextvar changes don't leak between separate parallel runs"""

    def set_and_return_contextvar(value: str) -> str:
        test_var.set(value)
        return test_var.get()

    # First run
    test_var.set("first_run")
    first_results = run_functions_tuples_in_parallel(
        [
            (set_and_return_contextvar, ("thread1",)),
            (set_and_return_contextvar, ("thread2",)),
        ]
    )

    # Verify first run results
    assert all(result in ["thread1", "thread2"] for result in first_results)

    # Second run should still see the main thread's value
    assert test_var.get() == "first_run"

    # Second run with different value
    test_var.set("second_run")
    second_results = run_functions_tuples_in_parallel(
        [
            (set_and_return_contextvar, ("thread3",)),
            (set_and_return_contextvar, ("thread4",)),
        ]
    )

    # Verify second run results
    assert all(result in ["thread3", "thread4"] for result in second_results)
@@ -1,5 +1,5 @@
# fill in the template
envsubst '$SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf
envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf

# wait for the api_server to be ready
echo "Waiting for API server to boot up; this may take a minute or two..."
@@ -36,6 +36,7 @@ services:
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
- INTEGRATION_TESTS_MODE=${INTEGRATION_TESTS_MODE:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
@@ -113,7 +113,6 @@ export function AssistantEditor({
documentSets,
user,
defaultPublic,
redirectType,
llmProviders,
tools,
shouldAddAssistantToUserPreferences,

@@ -124,7 +123,6 @@ export function AssistantEditor({
documentSets: DocumentSet[];
user: User | null;
defaultPublic: boolean;
redirectType: SuccessfulPersonaUpdateRedirectType;
llmProviders: FullLLMProvider[];
tools: ToolSnapshot[];
shouldAddAssistantToUserPreferences?: boolean;

@@ -502,7 +500,7 @@ export function AssistantEditor({
)
.map((message: { message: string; name?: string }) => ({
message: message.message,
name: message.name || message.message,
name: message.message,
}));

// don't set groups if marked as public
@@ -1,7 +1,7 @@
"use client";

import React, { useMemo, useState, useEffect } from "react";
import { Formik, Form, Field } from "formik";
import React, { useMemo } from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { usePopup } from "@/components/admin/connectors/Popup";
import {

@@ -13,17 +13,13 @@ import {
createSlackChannelConfig,
isPersonaASlackBotPersona,
updateSlackChannelConfig,
fetchSlackChannels,
} from "../lib";
import CardSection from "@/components/admin/CardSection";
import { useRouter } from "next/navigation";
import { Persona } from "@/app/admin/assistants/interfaces";
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { SEARCH_TOOL_ID, SEARCH_TOOL_NAME } from "@/app/chat/tools/constants";
import {
SlackChannelConfigFormFields,
SlackChannelConfigFormFieldsProps,
} from "./SlackChannelConfigFormFields";
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
import { SlackChannelConfigFormFields } from "./SlackChannelConfigFormFields";

export const SlackChannelConfigCreationForm = ({
slack_bot_id,
@@ -1,13 +1,7 @@
"use client";

import React, { useState, useEffect, useMemo } from "react";
import {
FieldArray,
Form,
useFormikContext,
ErrorMessage,
Field,
} from "formik";
import { FieldArray, useFormikContext, ErrorMessage, Field } from "formik";
import { CCPairDescriptor, DocumentSet } from "@/lib/types";
import {
Label,

@@ -18,14 +12,13 @@ import {
} from "@/components/admin/connectors/Field";
import { Button } from "@/components/ui/button";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection";
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
import { RadioGroup } from "@/components/ui/radio-group";
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
import { AlertCircle, View } from "lucide-react";
import { AlertCircle } from "lucide-react";
import { useRouter } from "next/navigation";
import {
Tooltip,

@@ -50,6 +43,7 @@ import {
import { Separator } from "@/components/ui/separator";

import { CheckFormField } from "@/components/ui/CheckField";
import { Input } from "@/components/ui/input";

export interface SlackChannelConfigFormFieldsProps {
isUpdate: boolean;

@@ -178,9 +172,13 @@ export function SlackChannelConfigFormFields({
);
}, [documentSets]);

const { data: channelOptions, isLoading } = useSWR(
const {
data: channelOptions,
error,
isLoading,
} = useSWR(
`/api/manage/admin/slack-app/bots/${slack_bot_id}/channels`,
async (url: string) => {
async () => {
const channels = await fetchSlackChannels(slack_bot_id);
return channels.map((channel: any) => ({
name: channel.name,

@@ -227,20 +225,34 @@ export function SlackChannelConfigFormFields({
>
Select A Slack Channel:
</label>{" "}
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
{error ? (
<div>
<div className="text-red-600 text-sm mb-4">
{error.message || "Unable to fetch Slack channels."}
{" Please enter the channel name manually."}
</div>
<TextFormField
name="channel_name"
label="Channel Name"
placeholder="Enter channel name"
/>
)}
</Field>
</div>
) : (
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
/>
)}
</Field>
)}
</>
)}
<div className="space-y-2 mt-4">
@@ -21,11 +21,7 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
<div className="px-32">
<div className="mx-auto container">
<CardSection className="!border-none !bg-transparent !ring-none">
<AssistantEditor
{...values}
defaultPublic={false}
redirectType={SuccessfulPersonaUpdateRedirectType.CHAT}
/>
<AssistantEditor {...values} defaultPublic={false} />
</CardSection>
</div>
</div>
@@ -26,7 +26,6 @@ export default async function Page() {
<AssistantEditor
{...values}
defaultPublic={false}
redirectType={SuccessfulPersonaUpdateRedirectType.CHAT}
shouldAddAssistantToUserPreferences={true}
/>
</CardSection>
@@ -47,6 +47,7 @@ import {
removeMessage,
sendMessage,
setMessageAsLatest,
updateLlmOverrideForChatSession,
updateParentChildren,
uploadFilesForChat,
useScrollonStream,

@@ -65,7 +66,7 @@ import {
import { usePopup } from "@/components/admin/connectors/Popup";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
import { useDocumentSelection } from "./useDocumentSelection";
import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
import { LlmDescriptor, useFilters, useLlmManager } from "@/lib/hooks";
import { ChatState, FeedbackType, RegenerationState } from "./types";
import { DocumentResults } from "./documentSidebar/DocumentResults";
import { OnyxInitializingLoader } from "@/components/OnyxInitializingLoader";

@@ -89,7 +90,11 @@ import {
import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import {
checkLLMSupportsImageInput,
getFinalLLM,
structureValue,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";

@@ -194,16 +199,6 @@ export function ChatPage({
return screenSize;
}

const { height: screenHeight } = useScreenSize();

const getContainerHeight = () => {
if (autoScrollEnabled) return undefined;

if (screenHeight < 600) return "20vh";
if (screenHeight < 1200) return "30vh";
return "40vh";
};

// handle redirect if chat page is disabled
// NOTE: this must be done here, in a client component since
// settings are passed in via Context and therefore aren't

@@ -222,6 +217,7 @@ export function ChatPage({
setProSearchEnabled(!proSearchEnabled);
};

const isInitialLoad = useRef(true);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);

const {

@@ -356,7 +352,7 @@ export function ChatPage({
]
);

const llmOverrideManager = useLlmOverride(
const llmManager = useLlmManager(
llmProviders,
selectedChatSession,
liveAssistant

@@ -520,8 +516,17 @@ export function ChatPage({
scrollInitialized.current = false;

if (!hasPerformedInitialScroll) {
if (isInitialLoad.current) {
setHasPerformedInitialScroll(true);
isInitialLoad.current = false;
}
clientScrollToBottom();

setTimeout(() => {
setHasPerformedInitialScroll(true);
}, 100);
} else if (isChatSessionSwitch) {
setHasPerformedInitialScroll(true);
clientScrollToBottom(true);
}

@@ -1130,6 +1135,56 @@ export function ChatPage({
});
};
const [uncaughtError, setUncaughtError] = useState<string | null>(null);
const [agenticGenerating, setAgenticGenerating] = useState(false);

const autoScrollEnabled =
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;

useScrollonStream({
chatState: currentSessionChatState,
scrollableDivRef,
scrollDist,
endDivRef,
debounceNumber,
mobile: settings?.isMobile,
enableAutoScroll: autoScrollEnabled,
});

// Track whether a message has been sent during this page load, keyed by chat session id
const [sessionHasSentLocalUserMessage, setSessionHasSentLocalUserMessage] =
useState<Map<string | null, boolean>>(new Map());

// Update the local state for a session once the user sends a message
const markSessionMessageSent = (sessionId: string | null) => {
setSessionHasSentLocalUserMessage((prev) => {
const newMap = new Map(prev);
newMap.set(sessionId, true);
return newMap;
});
};
const currentSessionHasSentLocalUserMessage = useMemo(
() => (sessionId: string | null) => {
return sessionHasSentLocalUserMessage.size === 0
? undefined
: sessionHasSentLocalUserMessage.get(sessionId) || false;
},
[sessionHasSentLocalUserMessage]
);

const { height: screenHeight } = useScreenSize();

const getContainerHeight = useMemo(() => {
return () => {
if (!currentSessionHasSentLocalUserMessage(chatSessionIdRef.current)) {
return undefined;
}
if (autoScrollEnabled) return undefined;

if (screenHeight < 600) return "40vh";
if (screenHeight < 1200) return "50vh";
return "60vh";
};
}, [autoScrollEnabled, screenHeight, currentSessionHasSentLocalUserMessage]);

const onSubmit = async ({
messageIdToResend,

@@ -1138,7 +1193,7 @@ export function ChatPage({
forceSearch,
isSeededChat,
alternativeAssistantOverride = null,
modelOverRide,
modelOverride,
regenerationRequest,
overrideFileDescriptors,
}: {

@@ -1148,7 +1203,7 @@ export function ChatPage({
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
modelOverRide?: LlmOverride;
modelOverride?: LlmDescriptor;
regenerationRequest?: RegenerationRequest | null;
overrideFileDescriptors?: FileDescriptor[];
} = {}) => {

@@ -1156,6 +1211,9 @@ export function ChatPage({
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);

// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);

if (currentChatState() != "input") {
if (currentChatState() == "uploading") {
setPopup({

@@ -1191,6 +1249,22 @@ export function ChatPage({
currChatSessionId = chatSessionIdRef.current as string;
}
frozenSessionId = currChatSessionId;
// update the selected model for the chat session if one is specified so that
// it persists across page reloads. Do not `await` here so that the message
// request can continue and this will just happen in the background.
// NOTE: only set the model override for the chat session once we send a
// message with it. If the user switches models and then starts a new
// chat session, it is unexpected for that model to be used when they
// return to this session the next day.
let finalLLM = modelOverride || llmManager.currentLlm;
updateLlmOverrideForChatSession(
currChatSessionId,
structureValue(
finalLLM.name || "",
finalLLM.provider || "",
finalLLM.modelName || ""
)
);

updateStatesWithNewSessionId(currChatSessionId);

@@ -1250,11 +1324,14 @@ export function ChatPage({
: null) ||
(messageMap.size === 1 ? Array.from(messageMap.values())[0] : null);

const currentAssistantId = alternativeAssistantOverride
|
||||
? alternativeAssistantOverride.id
|
||||
: alternativeAssistant
|
||||
? alternativeAssistant.id
|
||||
: liveAssistant.id;
|
||||
let currentAssistantId;
|
||||
if (alternativeAssistantOverride) {
|
||||
currentAssistantId = alternativeAssistantOverride.id;
|
||||
} else if (alternativeAssistant) {
|
||||
currentAssistantId = alternativeAssistant.id;
|
||||
} else {
|
||||
currentAssistantId = liveAssistant.id;
|
||||
}
|
||||
|
||||
resetInputBar();
|
||||
let messageUpdates: Message[] | null = null;
|
||||
@ -1326,15 +1403,13 @@ export function ChatPage({
|
||||
forceSearch,
|
||||
regenerate: regenerationRequest !== undefined,
|
||||
modelProvider:
|
||||
modelOverRide?.name ||
|
||||
llmOverrideManager.llmOverride.name ||
|
||||
undefined,
|
||||
modelOverride?.name || llmManager.currentLlm.name || undefined,
|
||||
modelVersion:
|
||||
modelOverRide?.modelName ||
|
||||
llmOverrideManager.llmOverride.modelName ||
|
||||
modelOverride?.modelName ||
|
||||
llmManager.currentLlm.modelName ||
|
||||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
|
||||
undefined,
|
||||
temperature: llmOverrideManager.temperature || undefined,
|
||||
temperature: llmManager.temperature || undefined,
|
||||
systemPromptOverride:
|
||||
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
|
||||
useExistingUserMessage: isSeededChat,
|
||||
@ -1802,7 +1877,7 @@ export function ChatPage({
|
||||
const [_, llmModel] = getFinalLLM(
|
||||
llmProviders,
|
||||
liveAssistant,
|
||||
llmOverrideManager.llmOverride
|
||||
llmManager.currentLlm
|
||||
);
|
||||
const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
|
||||
|
||||
@ -1857,7 +1932,6 @@ export function ChatPage({
|
||||
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
|
||||
const [untoggled, setUntoggled] = useState(false);
|
||||
const [loadingError, setLoadingError] = useState<string | null>(null);
|
||||
const [agenticGenerating, setAgenticGenerating] = useState(false);
|
||||
|
||||
const explicitlyUntoggle = () => {
|
||||
setShowHistorySidebar(false);
|
||||
@ -1899,19 +1973,6 @@ export function ChatPage({
|
||||
isAnonymousUser: user?.is_anonymous_user,
|
||||
});
|
||||
|
||||
const autoScrollEnabled =
|
||||
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
|
||||
|
||||
useScrollonStream({
|
||||
chatState: currentSessionChatState,
|
||||
scrollableDivRef,
|
||||
scrollDist,
|
||||
endDivRef,
|
||||
debounceNumber,
|
||||
mobile: settings?.isMobile,
|
||||
enableAutoScroll: autoScrollEnabled,
|
||||
});
|
||||
|
||||
// Virtualization + Scrolling related effects and functions
|
||||
const scrollInitialized = useRef(false);
|
||||
interface VisibleRange {
|
||||
@ -2121,7 +2182,7 @@ export function ChatPage({
|
||||
}, [searchParams, router]);
|
||||
|
||||
useEffect(() => {
|
||||
llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
|
||||
llmManager.updateImageFilesPresent(imageFileInMessageHistory);
|
||||
}, [imageFileInMessageHistory]);
|
||||
|
||||
const pathname = usePathname();
|
||||
@ -2175,9 +2236,9 @@ export function ChatPage({
|
||||
|
||||
function createRegenerator(regenerationRequest: RegenerationRequest) {
|
||||
// Returns new function that only needs `modelOverRide` to be specified when called
|
||||
return async function (modelOverRide: LlmOverride) {
|
||||
return async function (modelOverride: LlmDescriptor) {
|
||||
return await onSubmit({
|
||||
modelOverRide,
|
||||
modelOverride,
|
||||
messageIdToResend: regenerationRequest.parentMessage.messageId,
|
||||
regenerationRequest,
|
||||
forceSearch: regenerationRequest.forceSearch,
|
||||
@ -2258,9 +2319,7 @@ export function ChatPage({
|
||||
{(settingsToggled || userSettingsToggled) && (
|
||||
<UserSettingsModal
|
||||
setPopup={setPopup}
|
||||
setLlmOverride={(newOverride) =>
|
||||
llmOverrideManager.updateLLMOverride(newOverride)
|
||||
}
|
||||
setCurrentLlm={(newLlm) => llmManager.updateCurrentLlm(newLlm)}
|
||||
defaultModel={user?.preferences.default_model!}
|
||||
llmProviders={llmProviders}
|
||||
onClose={() => {
|
||||
@ -2324,7 +2383,7 @@ export function ChatPage({
|
||||
<ShareChatSessionModal
|
||||
assistantId={liveAssistant?.id}
|
||||
message={message}
|
||||
modelOverride={llmOverrideManager.llmOverride}
|
||||
modelOverride={llmManager.currentLlm}
|
||||
chatSessionId={sharedChatSession.id}
|
||||
existingSharedStatus={sharedChatSession.shared_status}
|
||||
onClose={() => setSharedChatSession(null)}
|
||||
@ -2342,7 +2401,7 @@ export function ChatPage({
|
||||
<ShareChatSessionModal
|
||||
message={message}
|
||||
assistantId={liveAssistant?.id}
|
||||
modelOverride={llmOverrideManager.llmOverride}
|
||||
modelOverride={llmManager.currentLlm}
|
||||
chatSessionId={chatSessionIdRef.current}
|
||||
existingSharedStatus={chatSessionSharedStatus}
|
||||
onClose={() => setSharingModalVisible(false)}
|
||||
@ -2572,6 +2631,7 @@ export function ChatPage({
|
||||
style={{ overflowAnchor: "none" }}
|
||||
key={currentSessionId()}
|
||||
className={
|
||||
(hasPerformedInitialScroll ? "" : " hidden ") +
|
||||
"desktop:-ml-4 w-full mx-auto " +
|
||||
"absolute mobile:top-0 desktop:top-0 left-0 " +
|
||||
(settings?.enterpriseSettings
|
||||
@ -3058,7 +3118,7 @@ export function ChatPage({
|
||||
messageId: message.messageId,
|
||||
parentMessage: parentMessage!,
|
||||
forceSearch: true,
|
||||
})(llmOverrideManager.llmOverride);
|
||||
})(llmManager.currentLlm);
|
||||
} else {
|
||||
setPopup({
|
||||
type: "error",
|
||||
@ -3203,7 +3263,7 @@ export function ChatPage({
|
||||
availableDocumentSets={documentSets}
|
||||
availableTags={tags}
|
||||
filterManager={filterManager}
|
||||
llmOverrideManager={llmOverrideManager}
|
||||
llmManager={llmManager}
|
||||
removeDocs={() => {
|
||||
clearSelectedDocuments();
|
||||
}}
|
||||
|
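
Note: the hunk above persists the chosen model for the session without awaiting the request, as its comment explains. A minimal self-contained sketch of that fire-and-forget pattern; persistModel and its endpoint are hypothetical stand-ins for updateLlmOverrideForChatSession, and only structureValue appears in the diff itself.

interface LlmDescriptor {
  name: string;
  provider: string;
  modelName: string;
}

const structureValue = (name: string, provider: string, modelName: string) =>
  `${name}__${provider}__${modelName}`;

async function persistModel(sessionId: string, encoded: string): Promise<void> {
  // hypothetical endpoint; the real helper lives alongside updateLlmOverrideForChatSession
  await fetch(`/api/chat/update-chat-session-model`, {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ chat_session_id: sessionId, new_alternate_model: encoded }),
  });
}

function persistSessionModelInBackground(sessionId: string, llm: LlmDescriptor): void {
  // Deliberately not awaited: the chat request proceeds while this write completes.
  void persistModel(
    sessionId,
    structureValue(llm.name || "", llm.provider || "", llm.modelName || "")
  ).catch((err) => console.error("failed to persist session model", err));
}
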

@ -1,8 +1,8 @@
import { useChatContext } from "@/components/context/ChatContext";
import {
getDisplayNameForModel,
LlmOverride,
useLlmOverride,
LlmDescriptor,
useLlmManager,
} from "@/lib/hooks";
import { StringOrNumberOption } from "@/components/Dropdown";

@ -106,13 +106,13 @@ export default function RegenerateOption({
onDropdownVisibleChange,
}: {
selectedAssistant: Persona;
regenerate: (modelOverRide: LlmOverride) => Promise<void>;
regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
overriddenModel?: string;
onHoverChange: (isHovered: boolean) => void;
onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
const { llmProviders } = useChatContext();
const llmOverrideManager = useLlmOverride(llmProviders);
const llmManager = useLlmManager(llmProviders);

const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);

@ -148,7 +148,7 @@ export default function RegenerateOption({
);

const currentModelName =
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override || llmName
: llmName);
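
Note: the regenerate prop above is the partially-applied function produced by createRegenerator in ChatPage. A reduced sketch of that closure pattern, with the onSubmit shape simplified and RegenerationRequest's fields flattened for brevity:

interface LlmDescriptor { name: string; provider: string; modelName: string; }
interface RegenerationRequest { parentMessageId: number; forceSearch?: boolean; }

type SubmitArgs = {
  modelOverride?: LlmDescriptor;
  messageIdToResend?: number;
  regenerationRequest?: RegenerationRequest;
  forceSearch?: boolean;
};

function createRegenerator(
  onSubmit: (args: SubmitArgs) => Promise<void>,
  regenerationRequest: RegenerationRequest
) {
  // Capture the request once; callers later supply only the model to retry with.
  return async (modelOverride: LlmDescriptor) =>
    onSubmit({
      modelOverride,
      messageIdToResend: regenerationRequest.parentMessageId,
      regenerationRequest,
      forceSearch: regenerationRequest.forceSearch,
    });
}
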

@ -6,7 +6,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import LLMPopover from "./LLMPopover";
import { InputPrompt } from "@/app/chat/interfaces";

import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { FilterManager, LlmManager } from "@/lib/hooks";
import { useChatContext } from "@/components/context/ChatContext";
import { ChatFileType, FileDescriptor } from "../interfaces";
import {
@ -180,7 +180,7 @@ interface ChatInputBarProps {
setMessage: (message: string) => void;
stopGenerating: () => void;
onSubmit: () => void;
llmOverrideManager: LlmOverrideManager;
llmManager: LlmManager;
chatState: ChatState;
alternativeAssistant: Persona | null;
// assistants
@ -225,7 +225,7 @@ export function ChatInputBar({
availableSources,
availableDocumentSets,
availableTags,
llmOverrideManager,
llmManager,
proSearchEnabled,
setProSearchEnabled,
}: ChatInputBarProps) {
@ -781,7 +781,7 @@ export function ChatInputBar({

<LLMPopover
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
llmManager={llmManager}
requiresImageGeneration={false}
currentAssistant={selectedAssistant}
/>

@ -16,7 +16,7 @@ import {
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmOverrideManager } from "@/lib/hooks";
import { LlmManager } from "@/lib/hooks";

import {
Tooltip,
@ -31,21 +31,19 @@ import { useUser } from "@/components/user/UserProvider";

interface LLMPopoverProps {
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
llmManager: LlmManager;
requiresImageGeneration?: boolean;
currentAssistant?: Persona;
}

export default function LLMPopover({
llmProviders,
llmOverrideManager,
llmManager,
requiresImageGeneration,
currentAssistant,
}: LLMPopoverProps) {
const [isOpen, setIsOpen] = useState(false);
const { user } = useUser();
const { llmOverride, updateLLMOverride } = llmOverrideManager;
const currentLlm = llmOverride.modelName;

const llmOptionsByProvider: {
[provider: string]: {
@ -93,19 +91,19 @@ export default function LLMPopover({
: null;

const [localTemperature, setLocalTemperature] = useState(
llmOverrideManager.temperature ?? 0.5
llmManager.temperature ?? 0.5
);

useEffect(() => {
setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
}, [llmOverrideManager.temperature]);
setLocalTemperature(llmManager.temperature ?? 0.5);
}, [llmManager.temperature]);

const handleTemperatureChange = (value: number[]) => {
setLocalTemperature(value[0]);
};

const handleTemperatureChangeComplete = (value: number[]) => {
llmOverrideManager.updateTemperature(value[0]);
llmManager.updateTemperature(value[0]);
};

return (
@ -120,15 +118,15 @@ export default function LLMPopover({
toggle
flexPriority="stiff"
name={getDisplayNameForModel(
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
defaultModelDisplayName ||
"Models"
)}
Icon={getProviderIcon(
llmOverrideManager?.llmOverride.provider ||
llmManager?.currentLlm.provider ||
defaultProvider?.provider ||
"anthropic",
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
defaultProvider?.default_model_name ||
"claude-3-5-sonnet-20240620"
)}
@ -147,12 +145,12 @@ export default function LLMPopover({
<button
key={index}
className={`w-full flex items-center gap-x-2 px-3 py-2 text-sm text-left hover:bg-background-100 dark:hover:bg-neutral-800 transition-colors duration-150 ${
currentLlm === name
llmManager.currentLlm.modelName === name
? "bg-background-100 dark:bg-neutral-900 text-text"
: "text-text-darker"
}`}
onClick={() => {
updateLLMOverride(destructureValue(value));
llmManager.updateCurrentLlm(destructureValue(value));
setIsOpen(false);
}}
>
@ -172,7 +170,7 @@ export default function LLMPopover({
);
}
})()}
{llmOverrideManager.imageFilesPresent &&
{llmManager.imageFilesPresent &&
!checkLLMSupportsImageInput(name) && (
<TooltipProvider>
<Tooltip delayDuration={0}>
@ -199,7 +197,7 @@ export default function LLMPopover({
<div className="w-full px-3 py-2">
<Slider
value={[localTemperature]}
max={llmOverrideManager.maxTemperature}
max={llmManager.maxTemperature}
min={0}
step={0.01}
onValueChange={handleTemperatureChange}
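
Note: the temperature slider above keeps a local value for smooth dragging and only writes through to llmManager when the drag completes. A reduced sketch of that controlled-slider pattern, with TemperatureManager standing in for the LlmManager fields involved:

import { useEffect, useState } from "react";

interface TemperatureManager {
  temperature: number | null;
  updateTemperature: (t: number) => void;
}

export function useLocalTemperature(manager: TemperatureManager) {
  const [local, setLocal] = useState(manager.temperature ?? 0.5);

  // Re-sync if another component changes the committed value.
  useEffect(() => {
    setLocal(manager.temperature ?? 0.5);
  }, [manager.temperature]);

  return {
    local,
    onChange: (value: number[]) => setLocal(value[0]),            // every tick of the drag
    onCommit: (value: number[]) => manager.updateTemperature(value[0]), // only on release
  };
}
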

@ -65,7 +65,7 @@ export function getChatRetentionInfo(
};
}

export async function updateModelOverrideForChatSession(
export async function updateLlmOverrideForChatSession(
chatSessionId: string,
newAlternateModel: string
) {
@ -236,7 +236,7 @@ export async function* sendMessage({
}
: null,
use_existing_user_message: useExistingUserMessage,
use_agentic_search: useLanggraph,
use_agentic_search: useLanggraph ?? false,
});

const response = await fetch(`/api/chat/send-message`, {
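
Note: one behavioral point on the use_agentic_search change above — an undefined property is dropped entirely by JSON.stringify, so `?? false` guarantees the field always serializes as a boolean:

const useLanggraph: boolean | undefined = undefined;

JSON.stringify({ use_agentic_search: useLanggraph });
// => '{}'  (undefined properties are omitted from the request body)
JSON.stringify({ use_agentic_search: useLanggraph ?? false });
// => '{"use_agentic_search":false}'  (the backend always receives a boolean)
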

@ -44,7 +44,7 @@ import { ValidSources } from "@/lib/types";
import { useMouseTracking } from "./hooks";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText, preprocessLaTeX } from "./codeUtils";
@ -117,7 +117,7 @@ export const AgenticMessage = ({
isComplete?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
regenerate?: (modelOverRide: LlmDescriptor) => Promise<void>;
setPresentingDocument?: (document: OnyxDocument) => void;
toggleDocDisplay?: (agentic: boolean) => void;
error?: string | null;

@ -58,7 +58,7 @@ import { useMouseTracking } from "./hooks";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText, preprocessLaTeX } from "./codeUtils";
@ -213,7 +213,7 @@ export const AIMessage = ({
handleForceSearch?: () => void;
retrievalDisabled?: boolean;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
regenerate?: (modelOverRide: LlmDescriptor) => Promise<void>;
setPresentingDocument: (document: OnyxDocument) => void;
removePadding?: boolean;
}) => {

@ -11,7 +11,7 @@ import { CopyButton } from "@/components/CopyButton";
import { SEARCH_PARAM_NAMES } from "../searchParams";
import { usePopup } from "@/components/admin/connectors/Popup";
import { structureValue } from "@/lib/llm/utils";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
import { Separator } from "@/components/ui/separator";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";

@ -38,7 +38,7 @@ async function generateShareLink(chatSessionId: string) {
async function generateSeedLink(
message?: string,
assistantId?: number,
modelOverride?: LlmOverride
modelOverride?: LlmDescriptor
) {
const baseUrl = `${window.location.protocol}//${window.location.host}`;
const model = modelOverride
@ -92,7 +92,7 @@ export function ShareChatSessionModal({
onClose: () => void;
message?: string;
assistantId?: number;
modelOverride?: LlmOverride;
modelOverride?: LlmDescriptor;
}) {
const [shareLink, setShareLink] = useState<string>(
existingSharedStatus === ChatSessionSharedStatus.Public
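
Note: generateSeedLink above serializes the optional LlmDescriptor into the seed URL via structureValue. A rough sketch of that construction; the query-parameter names are assumptions (the real ones come from SEARCH_PARAM_NAMES), and only structureValue's encoding is taken from the diff:

interface LlmDescriptor { name: string; provider: string; modelName: string; }

const structureValue = (name: string, provider: string, modelName: string) =>
  `${name}__${provider}__${modelName}`;

function buildSeedLink(message?: string, assistantId?: number, model?: LlmDescriptor) {
  const baseUrl = `${window.location.protocol}//${window.location.host}`;
  const params = new URLSearchParams();
  if (message) params.set("user-prompt", message); // param name assumed
  if (assistantId !== undefined) params.set("assistant", String(assistantId)); // assumed
  if (model) {
    params.set(
      "model-version",
      structureValue(model.name, model.provider, model.modelName)
    ); // assumed
  }
  return `${baseUrl}/chat?${params.toString()}`;
}
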

@ -1,6 +1,6 @@
import { useContext, useEffect, useRef, useState } from "react";
import { Modal } from "@/components/Modal";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
import { getDisplayNameForModel, LlmDescriptor } from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";

import { destructureValue, structureValue } from "@/lib/llm/utils";
@ -31,12 +31,12 @@ export function UserSettingsModal({
setPopup,
llmProviders,
onClose,
setLlmOverride,
setCurrentLlm,
defaultModel,
}: {
setPopup: (popupSpec: PopupSpec | null) => void;
llmProviders: LLMProviderDescriptor[];
setLlmOverride?: (newOverride: LlmOverride) => void;
setCurrentLlm?: (newLlm: LlmDescriptor) => void;
onClose: () => void;
defaultModel: string | null;
}) {
@ -127,18 +127,14 @@ export function UserSettingsModal({
);
});

const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
([provider, options]) => [...options]
);

const router = useRouter();
const handleChangedefaultModel = async (defaultModel: string | null) => {
try {
const response = await setUserDefaultModel(defaultModel);

if (response.ok) {
if (defaultModel && setLlmOverride) {
setLlmOverride(destructureValue(defaultModel));
if (defaultModel && setCurrentLlm) {
setCurrentLlm(destructureValue(defaultModel));
}
setPopup({
message: "Default model updated successfully",
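
Note: the handler above stores the default model server-side and, on success, immediately reflects it in the live session via setCurrentLlm(destructureValue(defaultModel)). A condensed sketch of that flow, with setUserDefaultModel and setCurrentLlm passed in rather than imported:

async function applyDefaultModel(
  defaultModel: string | null,
  setUserDefaultModel: (m: string | null) => Promise<Response>,
  setCurrentLlm?: (llm: { name: string; provider: string; modelName: string }) => void
): Promise<boolean> {
  const response = await setUserDefaultModel(defaultModel);
  if (response.ok) {
    if (defaultModel && setCurrentLlm) {
      // destructureValue splits "name__provider__modelName"
      const [name, provider, modelName] = defaultModel.split("__");
      setCurrentLlm({ name, provider, modelName });
    }
    return true;
  }
  return false;
}
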

@ -95,7 +95,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
}
}

if (enterpriseSettings && settings.pro_search_enabled == null) {
if (settings.pro_search_enabled == null) {
settings.pro_search_enabled = true;
}

@ -360,18 +360,18 @@ export const useUsers = ({ includeApiKeys }: UseUsersParams) => {
};
};

export interface LlmOverride {
export interface LlmDescriptor {
name: string;
provider: string;
modelName: string;
}

export interface LlmOverrideManager {
llmOverride: LlmOverride;
updateLLMOverride: (newOverride: LlmOverride) => void;
export interface LlmManager {
currentLlm: LlmDescriptor;
updateCurrentLlm: (newOverride: LlmDescriptor) => void;
temperature: number;
updateTemperature: (temperature: number) => void;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
updateModelOverrideBasedOnChatSession: (chatSession?: ChatSession) => void;
imageFilesPresent: boolean;
updateImageFilesPresent: (present: boolean) => void;
liveAssistant: Persona | null;
@ -400,7 +400,7 @@ Thus, the input should be

Changes take place as
- liveAssistant or currentChatSession changes (and the associated model override is set)
- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)
- (updateCurrentLlm) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)

If we have a live assistant, we should use that model override

@ -419,55 +419,78 @@ This approach ensures that user preferences are maintained for existing chats while
providing appropriate defaults for new conversations based on the available tools.
*/

export function useLlmOverride(
export function useLlmManager(
llmProviders: LLMProviderDescriptor[],
currentChatSession?: ChatSession,
liveAssistant?: Persona
): LlmOverrideManager {
): LlmManager {
const { user } = useUser();

const [userHasManuallyOverriddenLLM, setUserHasManuallyOverriddenLLM] =
useState(false);
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const [currentLlm, setCurrentLlm] = useState<LlmDescriptor>({
name: "",
provider: "",
modelName: "",
});

const llmOverrideUpdate = () => {
if (liveAssistant?.llm_model_version_override) {
setLlmOverride(
getValidLlmOverride(liveAssistant.llm_model_version_override)
);
} else if (currentChatSession?.current_alternate_model) {
setLlmOverride(
getValidLlmOverride(currentChatSession.current_alternate_model)
);
} else if (user?.preferences?.default_model) {
setLlmOverride(getValidLlmOverride(user.preferences.default_model));
return;
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);
const llmUpdate = () => {
/* Should be called when the live assistant or current chat session changes */

if (defaultProvider) {
setLlmOverride({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
});
// separate function so we can `return` to break out
const _llmUpdate = () => {
// if the user has overridden in this session and just switched to a brand
// new session, use their manually specified model
if (userHasManuallyOverriddenLLM && !currentChatSession) {
return;
}
}

if (currentChatSession?.current_alternate_model) {
setCurrentLlm(
getValidLlmDescriptor(currentChatSession.current_alternate_model)
);
} else if (liveAssistant?.llm_model_version_override) {
setCurrentLlm(
getValidLlmDescriptor(liveAssistant.llm_model_version_override)
);
} else if (userHasManuallyOverriddenLLM) {
// if the user has an override and there's nothing special about the
// current chat session, use the override
return;
} else if (user?.preferences?.default_model) {
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);

if (defaultProvider) {
setCurrentLlm({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
});
}
}
};

_llmUpdate();
setChatSession(currentChatSession || null);
};

const getValidLlmOverride = (
overrideModel: string | null | undefined
): LlmOverride => {
if (overrideModel) {
const model = destructureValue(overrideModel);
const getValidLlmDescriptor = (
modelName: string | null | undefined
): LlmDescriptor => {
if (modelName) {
const model = destructureValue(modelName);
if (!(model.modelName && model.modelName.length > 0)) {
const provider = llmProviders.find((p) =>
p.model_names.includes(overrideModel)
p.model_names.includes(modelName)
);
if (provider) {
return {
modelName: overrideModel,
modelName: modelName,
name: provider.name,
provider: provider.provider,
};
@ -491,38 +514,32 @@ export function useLlmOverride(
setImageFilesPresent(present);
};

const [llmOverride, setLlmOverride] = useState<LlmOverride>({
name: "",
provider: "",
modelName: "",
});

// Manually set the override
const updateLLMOverride = (newOverride: LlmOverride) => {
// Manually set the LLM
const updateCurrentLlm = (newLlm: LlmDescriptor) => {
const provider =
newOverride.provider ||
findProviderForModel(llmProviders, newOverride.modelName);
newLlm.provider || findProviderForModel(llmProviders, newLlm.modelName);
const structuredValue = structureValue(
newOverride.name,
newLlm.name,
provider,
newOverride.modelName
newLlm.modelName
);
setLlmOverride(getValidLlmOverride(structuredValue));
setCurrentLlm(getValidLlmDescriptor(structuredValue));
setUserHasManuallyOverriddenLLM(true);
};

const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
const updateModelOverrideBasedOnChatSession = (chatSession?: ChatSession) => {
if (chatSession && chatSession.current_alternate_model?.length > 0) {
setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
setCurrentLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
}
};

const [temperature, setTemperature] = useState<number>(() => {
llmOverrideUpdate();
llmUpdate();

if (currentChatSession?.current_temperature_override != null) {
return Math.min(
currentChatSession.current_temperature_override,
isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0
isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0
);
} else if (
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
@ -533,22 +550,23 @@ export function useLlmOverride(
});

const maxTemperature = useMemo(() => {
return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0;
}, [llmOverride]);
return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
}, [currentLlm]);

useEffect(() => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
const newTemperature = Math.min(temperature, 1.0);
setTemperature(newTemperature);
if (chatSession?.id) {
updateTemperatureOverrideForChatSession(chatSession.id, newTemperature);
}
}
}, [llmOverride]);
}, [currentLlm]);

useEffect(() => {
llmUpdate();

if (!chatSession && currentChatSession) {
setChatSession(currentChatSession || null);
if (temperature) {
updateTemperatureOverrideForChatSession(
currentChatSession.id,
@ -570,7 +588,7 @@ export function useLlmOverride(
}, [liveAssistant, currentChatSession]);

const updateTemperature = (temperature: number) => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
setTemperature((prevTemp) => Math.min(temperature, 1.0));
} else {
setTemperature(temperature);
@ -581,9 +599,9 @@ export function useLlmOverride(
};

return {
updateModelOverrideForChatSession,
llmOverride,
updateLLMOverride,
updateModelOverrideBasedOnChatSession,
currentLlm,
updateCurrentLlm,
temperature,
updateTemperature,
imageFilesPresent,
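
Note: the docstring and llmUpdate logic above implement a specific selection order. The same priority chain, restated as a standalone pure function for clarity — types are reduced to what the logic needs, and toDescriptor stands in for getValidLlmDescriptor:

interface LlmDescriptor { name: string; provider: string; modelName: string; }

interface ResolveArgs {
  userHasManuallyOverriddenLLM: boolean;
  hasCurrentChatSession: boolean;
  currentChatSessionModel?: string;  // chatSession.current_alternate_model
  assistantModelOverride?: string;   // liveAssistant.llm_model_version_override
  userDefaultModel?: string;         // user.preferences.default_model
  providerDefault?: LlmDescriptor;   // from the is_default_provider entry
  toDescriptor: (model: string) => LlmDescriptor;
}

// Returns the descriptor to switch to, or null to keep the current selection.
function resolveLlm(args: ResolveArgs): LlmDescriptor | null {
  // Manual override + brand-new session: keep what the user picked.
  if (args.userHasManuallyOverriddenLLM && !args.hasCurrentChatSession) return null;
  // 1. A model already attached to this chat session wins.
  if (args.currentChatSessionModel) return args.toDescriptor(args.currentChatSessionModel);
  // 2. Then an assistant-level override.
  if (args.assistantModelOverride) return args.toDescriptor(args.assistantModelOverride);
  // 3. Then a manual override made earlier in this page load.
  if (args.userHasManuallyOverriddenLLM) return null;
  // 4. Then the user's saved default, then the default provider.
  if (args.userDefaultModel) return args.toDescriptor(args.userDefaultModel);
  return args.providerDefault ?? null;
}
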

@ -1,11 +1,11 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";

export function getFinalLLM(
llmProviders: LLMProviderDescriptor[],
persona: Persona | null,
llmOverride: LlmOverride | null
currentLlm: LlmDescriptor | null
): [string, string] {
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
@ -26,9 +26,9 @@ export function getFinalLLM(
model = persona.llm_model_version_override || model;
}

if (llmOverride) {
provider = llmOverride.provider || provider;
model = llmOverride.modelName || model;
if (currentLlm) {
provider = currentLlm.provider || provider;
model = currentLlm.modelName || model;
}

return [provider, model];
@ -37,7 +37,7 @@ export function getFinalLLM(
export function getLLMProviderOverrideForPersona(
liveAssistant: Persona,
llmProviders: LLMProviderDescriptor[]
): LlmOverride | null {
): LlmDescriptor | null {
const overrideProvider = liveAssistant.llm_model_provider_override;
const overrideModel = liveAssistant.llm_model_version_override;

@ -135,7 +135,7 @@ export const structureValue = (
return `${name}__${provider}__${modelName}`;
};

export const destructureValue = (value: string): LlmOverride => {
export const destructureValue = (value: string): LlmDescriptor => {
const [displayName, provider, modelName] = value.split("__");
return {
name: displayName,
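
Note: a quick round-trip of the two helpers above. The encoding would break if any component itself contained "__", which the code assumes never happens; the model values here are illustrative only:

interface LlmDescriptor { name: string; provider: string; modelName: string; }

const structureValue = (name: string, provider: string, modelName: string): string =>
  `${name}__${provider}__${modelName}`;

const destructureValue = (value: string): LlmDescriptor => {
  const [displayName, provider, modelName] = value.split("__");
  return { name: displayName, provider, modelName };
};

const encoded = structureValue("My OpenAI", "openai", "gpt-4o");
// "My OpenAI__openai__gpt-4o"
const decoded = destructureValue(encoded);
// { name: "My OpenAI", provider: "openai", modelName: "gpt-4o" }
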

@ -1,5 +1,3 @@
import { LlmOverride } from "../hooks";

export async function setUserDefaultModel(
model: string | null
): Promise<Response> {