# Mirror of https://github.com/danswer-ai/danswer.git
# Synced 2025-06-12 17:10:54 +02:00
# 82 lines, 2.8 KiB, Python
import os
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
|
|
from danswer.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
|
|
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
|
|
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
|
|
|
|
|
@pytest.fixture
def bedrock_provider() -> WellKnownLLMProviderDescriptor:
    """Return the well-known LLM provider descriptor for Bedrock.

    Walks the descriptors yielded by ``fetch_available_well_known_llms()``
    and picks the first one whose ``name`` equals ``BEDROCK_PROVIDER_NAME``.

    Raises:
        AssertionError: if no descriptor with that name is available.
    """
    match = None
    for candidate in fetch_available_well_known_llms():
        if candidate.name == BEDROCK_PROVIDER_NAME:
            # Only the first matching descriptor is used.
            match = candidate
            break

    assert match is not None, "Bedrock provider not found"
    return match
|
|
|
|
|
|
def test_bedrock_llm_configuration(
    client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
    """Check that a Bedrock configuration built from the environment is accepted.

    AWS settings are read from the process environment; the region falls back
    to ``us-east-1`` when ``AWS_REGION_NAME`` is unset, while the access key id
    and secret are sent as ``None`` if their variables are missing. The payload
    is posted to ``/admin/llm/test`` and a 200 response is expected.

    Args:
        client: HTTP test client used to issue the request (fixture supplied
            outside this module - presumably a conftest; verify).
        bedrock_provider: descriptor supplying the default and fast-default
            model names.
    """
    # AWS settings come from the environment so no secrets live in the repo.
    aws_settings = {
        "AWS_REGION_NAME": os.getenv("AWS_REGION_NAME", "us-east-1"),
        "AWS_ACCESS_KEY_ID": os.getenv("AWS_ACCESS_KEY_ID"),
        "AWS_SECRET_ACCESS_KEY": os.getenv("AWS_SECRET_ACCESS_KEY"),
    }

    payload: dict[str, Any] = {
        "provider": BEDROCK_PROVIDER_NAME,
        "default_model_name": bedrock_provider.default_model,
        "fast_default_model_name": bedrock_provider.default_fast_model,
        "api_key": None,
        "api_base": None,
        "api_version": None,
        "custom_config": aws_settings,
    }

    response = client.post("/admin/llm/test", json=payload)

    # Include the response body in the failure message to ease debugging.
    assert response.status_code == 200, (
        f"Expected status code 200, but got {response.status_code}. Response: {response.text}"
    )
|
|
|
|
|
|
def test_bedrock_llm_configuration_invalid_key(
    client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
    """Check that a Bedrock configuration with bogus AWS credentials is rejected.

    Posts a payload with hard-coded invalid credentials to ``/admin/llm/test``
    and expects a 400 response whose body contains either
    ``"Invalid credentials"`` or ``"Invalid Authentication"``.

    Args:
        client: HTTP test client used to issue the request (fixture supplied
            outside this module - presumably a conftest; verify).
        bedrock_provider: descriptor supplying the default and fast-default
            model names.
    """
    # Deliberately wrong credentials; the region is fixed rather than read
    # from the environment.
    bogus_aws_settings = {
        "AWS_REGION_NAME": "us-east-1",
        "AWS_ACCESS_KEY_ID": "invalid_access_key_id",
        "AWS_SECRET_ACCESS_KEY": "invalid_secret_access_key",
    }

    payload: dict[str, Any] = {
        "provider": BEDROCK_PROVIDER_NAME,
        "default_model_name": bedrock_provider.default_model,
        "fast_default_model_name": bedrock_provider.default_fast_model,
        "api_key": None,
        "api_base": None,
        "api_version": None,
        "custom_config": bogus_aws_settings,
    }

    response = client.post("/admin/llm/test", json=payload)

    assert response.status_code == 400, (
        f"Expected status code 400, but got {response.status_code}. Response: {response.text}"
    )

    # Either wording is accepted as evidence of a credential failure.
    accepted_markers = ("Invalid credentials", "Invalid Authentication")
    assert any(marker in response.text for marker in accepted_markers), (
        f"Expected error message about invalid credentials, but got: {response.text}"
    )
|