Add proper typing such that tests pass mypy (#2301)

* add proper typing such that tests pass mypy

* nit (squash)

* minor update
This commit is contained in:
pablodanswer
2024-09-02 14:03:53 -07:00
committed by GitHub
parent 033ec0b6b1
commit 2b14afe878
6 changed files with 24 additions and 11 deletions

View File

@@ -32,17 +32,24 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
try:
return confluence_call(*args, **kwargs)
except HTTPError as e:
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
retry_after_header = e.response.headers.get("Retry-After")
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
try:
retry_after = int(e.response.headers.get("Retry-After"))
except (ValueError, TypeError):
pass
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
except ValueError:
pass
if retry_after:
if retry_after is not None:
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)

View File

@@ -85,7 +85,8 @@ def check_internet_connection(url: str) -> None:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.HTTPError as e:
status_code = e.response.status_code
# Extract status code from the response, defaulting to -1 if response is None
status_code = e.response.status_code if e.response is not None else -1
error_msg = {
400: "Bad Request",
401: "Unauthorized",

View File

@@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]:
get_default_admin_user_emails_fn: Callable[
[], list[str]
] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: []
"danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
)
return get_default_admin_user_emails_fn()

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from io import BytesIO
from typing import Any
from typing import cast
from uuid import uuid4
@@ -73,5 +75,7 @@ def save_file_from_url(url: str) -> str:
def save_files_from_urls(urls: list[str]) -> list[str]:
funcs = [(save_file_from_url, (url,)) for url in urls]
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url,)) for url in urls
]
return run_functions_tuples_in_parallel(funcs)

View File

@@ -253,8 +253,8 @@ def search_postprocessing(
if not retrieved_sections:
# Avoids trying to rerank an empty list which throws an error
yield []
yield []
yield cast(list[InferenceSection], [])
yield cast(list[SectionRelevancePiece], [])
return
rerank_task_id = None

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
@@ -47,7 +48,7 @@ class ToolRunner:
def check_which_tools_should_run_for_non_tool_calling_llm(
tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM
) -> list[dict[str, Any] | None]:
tool_args_list = [
tool_args_list: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(tool.get_args_for_non_tool_calling_llm, (query, history, llm))
for tool in tools
]