Add proper typing such that tests pass mypy (#2301)

* add proper typing such that tests pass mypy

* nit (squash)

* minor update
This commit is contained in:
pablodanswer
2024-09-02 14:03:53 -07:00
committed by GitHub
parent 033ec0b6b1
commit 2b14afe878
6 changed files with 24 additions and 11 deletions

View File

@@ -32,17 +32,24 @@ def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
try: try:
return confluence_call(*args, **kwargs) return confluence_call(*args, **kwargs)
except HTTPError as e: except HTTPError as e:
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
retry_after_header = e.response.headers.get("Retry-After")
if ( if (
e.response.status_code == 429 e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower() or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
): ):
retry_after = None retry_after = None
if retry_after_header is not None:
try: try:
retry_after = int(e.response.headers.get("Retry-After")) retry_after = int(retry_after_header)
except (ValueError, TypeError): except ValueError:
pass pass
if retry_after: if retry_after is not None:
logger.warning( logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..." f"Rate limit hit. Retrying after {retry_after} seconds..."
) )

View File

@@ -85,7 +85,8 @@ def check_internet_connection(url: str) -> None:
response = requests.get(url, timeout=3) response = requests.get(url, timeout=3)
response.raise_for_status() response.raise_for_status()
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
status_code = e.response.status_code # Extract status code from the response, defaulting to -1 if response is None
status_code = e.response.status_code if e.response is not None else -1
error_msg = { error_msg = {
400: "Bad Request", 400: "Bad Request",
401: "Unauthorized", 401: "Unauthorized",

View File

@@ -28,7 +28,7 @@ def get_default_admin_user_emails() -> list[str]:
get_default_admin_user_emails_fn: Callable[ get_default_admin_user_emails_fn: Callable[
[], list[str] [], list[str]
] = fetch_versioned_implementation_with_fallback( ] = fetch_versioned_implementation_with_fallback(
"danswer.auth.users", "get_default_admin_user_emails_", lambda: [] "danswer.auth.users", "get_default_admin_user_emails_", lambda: list[str]()
) )
return get_default_admin_user_emails_fn() return get_default_admin_user_emails_fn()

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from io import BytesIO from io import BytesIO
from typing import Any
from typing import cast from typing import cast
from uuid import uuid4 from uuid import uuid4
@@ -73,5 +75,7 @@ def save_file_from_url(url: str) -> str:
def save_files_from_urls(urls: list[str]) -> list[str]: def save_files_from_urls(urls: list[str]) -> list[str]:
funcs = [(save_file_from_url, (url,)) for url in urls] funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url,)) for url in urls
]
return run_functions_tuples_in_parallel(funcs) return run_functions_tuples_in_parallel(funcs)

View File

@@ -253,8 +253,8 @@ def search_postprocessing(
if not retrieved_sections: if not retrieved_sections:
# Avoids trying to rerank an empty list which throws an error # Avoids trying to rerank an empty list which throws an error
yield [] yield cast(list[InferenceSection], [])
yield [] yield cast(list[SectionRelevancePiece], [])
return return
rerank_task_id = None rerank_task_id = None

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
@@ -47,7 +48,7 @@ class ToolRunner:
def check_which_tools_should_run_for_non_tool_calling_llm( def check_which_tools_should_run_for_non_tool_calling_llm(
tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM
) -> list[dict[str, Any] | None]: ) -> list[dict[str, Any] | None]:
tool_args_list = [ tool_args_list: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(tool.get_args_for_non_tool_calling_llm, (query, history, llm)) (tool.get_args_for_non_tool_calling_llm, (query, history, llm))
for tool in tools for tool in tools
] ]