Update litellm to fix bedrock models

This commit is contained in:
Chris Weaver 2024-10-01 20:09:57 -07:00 committed by GitHub
parent fffb9c155a
commit b8232e0681
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 170 additions and 5 deletions
.github/workflows
backend
danswer/llm
requirements
tests/daily

@ -0,0 +1,58 @@
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
- name: Alert on Failure
if: failure() && github.event_name == 'schedule'
env:
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
run: |
curl -X POST \
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK

@ -290,10 +290,12 @@ class DefaultMultiLLM(LLM):
return litellm.completion(
# model choice
model=f"{self.config.model_provider}/{self.config.model_name}",
api_key=self._api_key,
base_url=self._api_base,
api_version=self._api_version,
custom_llm_provider=self._custom_llm_provider,
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock
api_key=self._api_key or None,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
# actual input
messages=prompt,
tools=tools,

@ -28,7 +28,7 @@ jsonref==1.1.0
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.47.1
litellm==1.48.7
llama-index==0.9.45
Mako==1.2.4
msal==1.28.0

@ -0,0 +1,24 @@
import os
from collections.abc import Generator
from typing import Any
import pytest
from fastapi.testclient import TestClient
from danswer.main import fetch_versioned_implementation
from danswer.utils.logger import setup_logger
logger = setup_logger()
@pytest.fixture(scope="function")
def client() -> Generator[TestClient, Any, None]:
# Set environment variables
os.environ["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "True"
# Initialize TestClient with the FastAPI app
app = fetch_versioned_implementation(
module="danswer.main", attribute="get_application"
)()
client = TestClient(app)
yield client

@ -0,0 +1,81 @@
import os
from typing import Any
import pytest
from fastapi.testclient import TestClient
from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
@pytest.fixture
def bedrock_provider() -> WellKnownLLMProviderDescriptor:
provider = next(
(
provider
for provider in fetch_available_well_known_llms()
if provider.name == BEDROCK_PROVIDER_NAME
),
None,
)
assert provider is not None, "Bedrock provider not found"
return provider
def test_bedrock_llm_configuration(
client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": BEDROCK_PROVIDER_NAME,
"default_model_name": bedrock_provider.default_model,
"fast_default_model_name": bedrock_provider.default_fast_model,
"api_key": None,
"api_base": None,
"api_version": None,
"custom_config": {
"AWS_REGION_NAME": os.environ.get("AWS_REGION_NAME", "us-east-1"),
"AWS_ACCESS_KEY_ID": os.environ.get("AWS_ACCESS_KEY_ID"),
"AWS_SECRET_ACCESS_KEY": os.environ.get("AWS_SECRET_ACCESS_KEY"),
},
}
# Send the test request
response = client.post("/admin/llm/test", json=test_request)
# Assert the response
assert (
response.status_code == 200
), f"Expected status code 200, but got {response.status_code}. Response: {response.text}"
def test_bedrock_llm_configuration_invalid_key(
client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": BEDROCK_PROVIDER_NAME,
"default_model_name": bedrock_provider.default_model,
"fast_default_model_name": bedrock_provider.default_fast_model,
"api_key": None,
"api_base": None,
"api_version": None,
"custom_config": {
"AWS_REGION_NAME": "us-east-1",
"AWS_ACCESS_KEY_ID": "invalid_access_key_id",
"AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
},
}
# Send the test request
response = client.post("/admin/llm/test", json=test_request)
# Assert the response
assert (
response.status_code == 400
), f"Expected status code 400, but got {response.status_code}. Response: {response.text}"
assert (
"Invalid credentials" in response.text
or "Invalid Authentication" in response.text
), f"Expected error message about invalid credentials, but got: {response.text}"