Merge pull request #3893 from onyx-dot-app/mypy_random

Mypy random fixes
This commit is contained in:
rkuo-danswer 2025-02-04 16:02:18 -08:00 committed by GitHub
commit 5854b39dd4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 61 additions and 13 deletions

View File

View File

@@ -286,6 +286,7 @@ def prepare_authorization_request(
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
@@ -554,6 +555,7 @@ def handle_google_drive_oauth_callback(
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)

View File

@@ -245,6 +245,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
user: User
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
@@ -368,6 +370,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User
try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
@@ -1043,6 +1047,8 @@ async def api_key_dep(
if AUTH_TYPE == AuthType.DISABLED:
return None
user: User | None = None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")

View File

@@ -586,11 +586,12 @@ def connector_indexing_proxy_task(
# if the job is done, clean up and break
if job.done():
exit_code: int | None
try:
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
exit_code = None
if job.process:
exit_code = job.process.exitcode

View File

@@ -446,6 +446,7 @@ def try_creating_indexing_task(
if not acquired:
return None
redis_connector_index: RedisConnectorIndex
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)

View File

@@ -747,6 +747,7 @@ def cloud_check_alembic() -> bool | None:
revision_counts: dict[str, int] = {}
out_of_date_tenants: dict[str, str | None] = {}
top_revision: str = ""
tenant_ids: list[str] | list[None] = []
try:
# map each tenant_id to its revision

View File

@@ -168,6 +168,7 @@ def document_by_cc_pair_cleanup_task(
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
@@ -247,6 +248,7 @@ def cloud_beat_task_generator(
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] | list[None] = []
try:
tenant_ids = get_all_tenant_ids()

View File

@@ -1033,6 +1033,7 @@ def vespa_metadata_sync_task(
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"

View File

@@ -239,6 +239,7 @@ def _run_indexing(
callback=callback,
)
tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
@@ -255,6 +256,8 @@ def _run_indexing(
document_count = 0
chunk_count = 0
run_end_dt = None
tracer_counter: int
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
@@ -265,6 +268,7 @@ def _run_indexing(
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0
try:
window_start = max(
@@ -289,7 +293,6 @@ def _run_indexing(
tenant_id=tenant_id,
)
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():

View File

@@ -87,6 +87,7 @@ from onyx.file_store.utils import save_files
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.natural_language_processing.utils import get_tokenizer
@@ -349,7 +350,8 @@ def stream_chat_message_objects(
new_msg_req.chunks_above = 0
new_msg_req.chunks_below = 0
llm = None
llm: LLM
try:
user_id = user.id if user is not None else None

View File

@@ -369,11 +369,12 @@ class AirtableConnector(LoadConnector):
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
record_documents: list[Document] = []
# Process records in batches
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
record_documents: list[Document] = []
record_documents = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks

View File

@@ -99,18 +99,18 @@ class AsanaAPI:
project = self.project_api.get_project(project_gid, opts={})
if project["archived"]:
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
return []
yield from []
if not project["team"] or not project["team"]["gid"]:
logger.info(
f"Skipping project without a team: {project['name']} ({project_gid})"
)
return []
yield from []
if project["privacy_setting"] == "private":
if self.team_gid and project["team"]["gid"] != self.team_gid:
logger.info(
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
)
return []
yield from []
else:
logger.info(
f"Processing private project in configured team: {project['name']} ({project_gid})"

View File

@@ -26,6 +26,7 @@ def _get_google_service(
creds: ServiceAccountCredentials | OAuthCredentials,
user_email: str | None = None,
) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService:
service: Resource
if isinstance(creds, ServiceAccountCredentials):
creds = creds.with_subject(user_email)
service = build(service_name, service_version, credentials=creds)

View File

@@ -59,6 +59,7 @@ def _clean_salesforce_dict(data: dict | list) -> dict | list:
elif isinstance(data, list):
filtered_list = []
for item in data:
filtered_item: dict | list
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists

View File

@@ -221,6 +221,8 @@ def insert_document_set(
group_ids=document_set_creation_request.groups or [],
)
new_document_set_row: DocumentSetDBModel
ds_cc_pairs: list[DocumentSet__ConnectorCredentialPair]
try:
new_document_set_row = DocumentSetDBModel(
name=document_set_creation_request.name,

View File

@@ -365,7 +365,7 @@ def extract_file_text(
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
)
# Fall through to normal processing
final_extension: str
if file_name or extension:
if extension is not None:
final_extension = extension

View File

@@ -223,6 +223,8 @@ class Chunker:
large_chunk_id=None,
)
section_link_text: str
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""

View File

@@ -5,6 +5,7 @@ import sys
import threading
import time
from collections.abc import Callable
from contextvars import Token
from threading import Event
from types import FrameType
from typing import Any
@@ -250,6 +251,8 @@ class SlackbotHandler:
"""
all_tenants = get_all_tenant_ids()
token: Token[str]
# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
if (
@@ -771,6 +774,7 @@ def process_message(
client=client.web_client, channel_id=channel
)
token: Token[str] | None = None
# Set the current tenant ID at the beginning for all DB calls within this thread
if client.tenant_id:
logger.info(f"Setting tenant ID to {client.tenant_id}")
@@ -825,7 +829,7 @@ def process_message(
if notify_no_answer:
apologize_for_fail(details, client)
finally:
if client.tenant_id:
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -518,7 +518,7 @@ def read_slack_thread(
message_type = MessageType.USER
else:
self_slack_bot_id = get_onyx_bot_slack_bot_id(client)
blocks: Any
if reply.get("user") == self_slack_bot_id:
# OnyxBot response
message_type = MessageType.ASSISTANT

View File

@@ -1,5 +1,6 @@
import time
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4
@@ -96,7 +97,7 @@ class RedisConnectorPermissionSync:
@property
def payload(self) -> RedisConnectorPermissionSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
import redis
@@ -82,7 +83,7 @@ class RedisConnectorExternalGroupSync:
@property
def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from uuid import uuid4
@@ -91,7 +92,7 @@ class RedisConnectorIndex:
@property
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
fence_bytes = cast(Any, self.redis.get(self.fence_key))
if fence_bytes is None:
return None

View File

@@ -271,6 +271,8 @@ def bulk_invite_users(
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
new_invited_emails = []
email: str
try:
for email in emails:
email_info = validate_email(email)

View File

@@ -198,6 +198,7 @@ def process_all_chat_feedback(onyx_url: str, api_key: str | None) -> None:
r_sessions = get_chat_sessions(onyx_url, headers, user_id)
logger.info(f"user={user_id} num_sessions={len(r_sessions.sessions)}")
for session in r_sessions.sessions:
s: ChatSessionSnapshot
try:
s = get_session_history(onyx_url, headers, session.id)
except requests.exceptions.HTTPError:

View File

@@ -6,6 +6,7 @@ from unittest.mock import patch
import pytest
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.models import Document
@pytest.fixture
@@ -41,6 +42,10 @@ def test_confluence_connector_basic(
assert len(doc_batch) == 3
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None
txt_doc: Document | None = None
for doc in doc_batch:
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
page_doc = doc
@@ -49,6 +54,7 @@ def test_confluence_connector_basic(
elif doc.semantic_identifier == "Page Within A Page":
page_within_a_page_doc = doc
assert page_within_a_page_doc is not None
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
assert page_within_a_page_doc.primary_owners
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
@@ -62,6 +68,7 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
)
assert page_doc is not None
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
assert page_doc.metadata["labels"] == ["testlabel"]
assert page_doc.primary_owners
@@ -75,6 +82,7 @@ def test_confluence_connector_basic(
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
)
assert txt_doc is not None
assert txt_doc.semantic_identifier == "small-file.txt"
assert len(txt_doc.sections) == 1
assert txt_doc.sections[0].text == "small"

View File

@@ -110,6 +110,8 @@ def test_docs_retrieval(
for doc in retrieved_docs:
id = doc.id
retrieved_primary_owner_emails: set[str | None] = set()
retrieved_secondary_owner_emails: set[str | None] = set()
if doc.primary_owners:
retrieved_primary_owner_emails = set(
[owner.email for owner in doc.primary_owners]

View File

@@ -165,6 +165,7 @@ class UserManager:
target_status: bool,
user_performing_action: DATestUser,
) -> DATestUser:
url_substring: str
if target_status is True:
url_substring = "activate"
elif target_status is False:

View File

@@ -54,6 +54,7 @@ def google_drive_test_env_setup() -> (
service_account_key = os.environ["FULL_CONTROL_DRIVE_SERVICE_ACCOUNT"]
drive_id: str | None = None
drive_service: GoogleDriveService | None = None
try:
credentials = {