mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-18 05:41:58 +01:00
Update litellm to fix bedrock models (#2649)
This commit is contained in:
parent
fffb9c155a
commit
b8232e0681
.github/workflows
backend
58
.github/workflows/pr-python-model-tests.yml
vendored
Normal file
58
.github/workflows/pr-python-model-tests.yml
vendored
Normal file
@ -0,0 +1,58 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
|
||||
env:
|
||||
# Bedrock
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
@ -290,10 +290,12 @@ class DefaultMultiLLM(LLM):
|
||||
return litellm.completion(
|
||||
# model choice
|
||||
model=f"{self.config.model_provider}/{self.config.model_name}",
|
||||
api_key=self._api_key,
|
||||
base_url=self._api_base,
|
||||
api_version=self._api_version,
|
||||
custom_llm_provider=self._custom_llm_provider,
|
||||
# NOTE: have to pass in None instead of empty string for these
|
||||
# otherwise litellm can have some issues with bedrock
|
||||
api_key=self._api_key or None,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
# actual input
|
||||
messages=prompt,
|
||||
tools=tools,
|
||||
|
@ -28,7 +28,7 @@ jsonref==1.1.0
|
||||
langchain==0.1.17
|
||||
langchain-core==0.1.50
|
||||
langchain-text-splitters==0.0.1
|
||||
litellm==1.47.1
|
||||
litellm==1.48.7
|
||||
llama-index==0.9.45
|
||||
Mako==1.2.4
|
||||
msal==1.28.0
|
||||
|
24
backend/tests/daily/conftest.py
Normal file
24
backend/tests/daily/conftest.py
Normal file
@ -0,0 +1,24 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from danswer.main import fetch_versioned_implementation
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client() -> Generator[TestClient, Any, None]:
|
||||
# Set environment variables
|
||||
os.environ["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "True"
|
||||
|
||||
# Initialize TestClient with the FastAPI app
|
||||
app = fetch_versioned_implementation(
|
||||
module="danswer.main", attribute="get_application"
|
||||
)()
|
||||
client = TestClient(app)
|
||||
yield client
|
81
backend/tests/daily/llm/test_bedrock.py
Normal file
81
backend/tests/daily/llm/test_bedrock.py
Normal file
@ -0,0 +1,81 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
|
||||
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bedrock_provider() -> WellKnownLLMProviderDescriptor:
|
||||
provider = next(
|
||||
(
|
||||
provider
|
||||
for provider in fetch_available_well_known_llms()
|
||||
if provider.name == BEDROCK_PROVIDER_NAME
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert provider is not None, "Bedrock provider not found"
|
||||
return provider
|
||||
|
||||
|
||||
def test_bedrock_llm_configuration(
|
||||
client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
|
||||
) -> None:
|
||||
# Prepare the test request payload
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": BEDROCK_PROVIDER_NAME,
|
||||
"default_model_name": bedrock_provider.default_model,
|
||||
"fast_default_model_name": bedrock_provider.default_fast_model,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"custom_config": {
|
||||
"AWS_REGION_NAME": os.environ.get("AWS_REGION_NAME", "us-east-1"),
|
||||
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
|
||||
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
|
||||
},
|
||||
}
|
||||
|
||||
# Send the test request
|
||||
response = client.post("/admin/llm/test", json=test_request)
|
||||
|
||||
# Assert the response
|
||||
assert (
|
||||
response.status_code == 200
|
||||
), f"Expected status code 200, but got {response.status_code}. Response: {response.text}"
|
||||
|
||||
|
||||
def test_bedrock_llm_configuration_invalid_key(
|
||||
client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
|
||||
) -> None:
|
||||
# Prepare the test request payload with invalid credentials
|
||||
test_request: dict[str, Any] = {
|
||||
"provider": BEDROCK_PROVIDER_NAME,
|
||||
"default_model_name": bedrock_provider.default_model,
|
||||
"fast_default_model_name": bedrock_provider.default_fast_model,
|
||||
"api_key": None,
|
||||
"api_base": None,
|
||||
"api_version": None,
|
||||
"custom_config": {
|
||||
"AWS_REGION_NAME": "us-east-1",
|
||||
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
|
||||
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
|
||||
},
|
||||
}
|
||||
|
||||
# Send the test request
|
||||
response = client.post("/admin/llm/test", json=test_request)
|
||||
|
||||
# Assert the response
|
||||
assert (
|
||||
response.status_code == 400
|
||||
), f"Expected status code 400, but got {response.status_code}. Response: {response.text}"
|
||||
assert (
|
||||
"Invalid credentials" in response.text
|
||||
or "Invalid Authentication" in response.text
|
||||
), f"Expected error message about invalid credentials, but got: {response.text}"
|
Loading…
x
Reference in New Issue
Block a user