danswer/backend/tests/daily/llm/test_bedrock.py
2024-10-01 20:09:57 -07:00

82 lines
2.8 KiB
Python

import os
from typing import Any
import pytest
from fastapi.testclient import TestClient
from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
@pytest.fixture
def bedrock_provider() -> WellKnownLLMProviderDescriptor:
    """Return the first well-known LLM provider named ``BEDROCK_PROVIDER_NAME``.

    The test fails with an AssertionError ("Bedrock provider not found") when
    ``fetch_available_well_known_llms()`` yields no provider with that name.
    """
    found = None
    for candidate in fetch_available_well_known_llms():
        if candidate.name == BEDROCK_PROVIDER_NAME:
            # Stop at the first match, mirroring next() over a generator.
            found = candidate
            break
    assert found is not None, "Bedrock provider not found"
    return found
def test_bedrock_llm_configuration(
    client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
    """POST a Bedrock configuration to ``/admin/llm/test`` and expect HTTP 200.

    Credentials come from the environment: ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` are sent as-is (``None`` when unset), and
    ``AWS_REGION_NAME`` falls back to ``"us-east-1"``. Model names are taken
    from the Bedrock provider descriptor fixture.
    """
    env = os.environ
    aws_config = {
        "AWS_REGION_NAME": env.get("AWS_REGION_NAME", "us-east-1"),
        "AWS_ACCESS_KEY_ID": env.get("AWS_ACCESS_KEY_ID"),
        "AWS_SECRET_ACCESS_KEY": env.get("AWS_SECRET_ACCESS_KEY"),
    }
    # Bedrock authenticates via custom_config, so the generic API fields are null.
    payload: dict[str, Any] = dict(
        provider=BEDROCK_PROVIDER_NAME,
        default_model_name=bedrock_provider.default_model,
        fast_default_model_name=bedrock_provider.default_fast_model,
        api_key=None,
        api_base=None,
        api_version=None,
        custom_config=aws_config,
    )

    resp = client.post("/admin/llm/test", json=payload)

    status = resp.status_code
    assert status == 200, (
        f"Expected status code 200, but got {status}. Response: {resp.text}"
    )
def test_bedrock_llm_configuration_invalid_key(
    client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
    """POST bogus Bedrock credentials to ``/admin/llm/test`` and expect rejection.

    The endpoint must answer HTTP 400, and the response body must mention
    either "Invalid credentials" or "Invalid Authentication".
    """
    bogus_aws_config = dict(
        AWS_REGION_NAME="us-east-1",
        AWS_ACCESS_KEY_ID="invalid_access_key_id",
        AWS_SECRET_ACCESS_KEY="invalid_secret_access_key",
    )
    # Same shape as a valid request; only the AWS credentials are fake.
    payload: dict[str, Any] = dict(
        provider=BEDROCK_PROVIDER_NAME,
        default_model_name=bedrock_provider.default_model,
        fast_default_model_name=bedrock_provider.default_fast_model,
        api_key=None,
        api_base=None,
        api_version=None,
        custom_config=bogus_aws_config,
    )

    resp = client.post("/admin/llm/test", json=payload)

    status = resp.status_code
    assert status == 400, (
        f"Expected status code 400, but got {status}. Response: {resp.text}"
    )

    body = resp.text
    # Either wording is accepted for the rejection message.
    expected_markers = ("Invalid credentials", "Invalid Authentication")
    assert any(marker in body for marker in expected_markers), (
        f"Expected error message about invalid credentials, but got: {body}"
    )