Correct serialization and validation of ThreadSafeDict

Evan Lohn 2025-03-13 19:38:49 -07:00
parent cdbcf336a0
commit 06ca775dbb
5 changed files with 21 additions and 10 deletions

View File

@ -158,6 +158,8 @@ try:
except ValueError:
INDEX_BATCH_SIZE = 16
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
# Below are intended to match the env variable names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
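
The new MAX_DRIVE_WORKERS setting follows the same environment-driven pattern as the surrounding config values: it is read from the MAX_DRIVE_WORKERS environment variable and falls back to 4. A quick sketch of overriding it (not part of this commit, and assuming onyx.configs.app_configs has not been imported yet in the process):

import os

os.environ["MAX_DRIVE_WORKERS"] = "8"  # hypothetical override for a larger deployment

from onyx.configs.app_configs import MAX_DRIVE_WORKERS

assert MAX_DRIVE_WORKERS == 8  # would be 4 if the variable were unset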

View File

@ -14,6 +14,7 @@ from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_DRIVE_WORKERS
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@ -435,7 +436,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
for email in all_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=10)
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
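
With this change the per-user retrieval generators are drained using the configured MAX_DRIVE_WORKERS instead of a hard-coded 10. A minimal usage sketch of parallel_yield with the new setting (not part of this commit; the toy generator is illustrative and parallel_yield is assumed to live in onyx.utils.threadpool_concurrency alongside ThreadSafeDict):

from onyx.configs.app_configs import MAX_DRIVE_WORKERS
from onyx.utils.threadpool_concurrency import parallel_yield

def fetch_for_user(email: str):
    # stand-in for the real per-user Drive retrieval generator
    for i in range(3):
        yield f"{email}:file-{i}"

gens = [fetch_for_user(email) for email in ["a@example.com", "b@example.com"]]
for doc in parallel_yield(gens, max_workers=MAX_DRIVE_WORKERS):
    print(doc)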

View File

@ -1,6 +1,8 @@
from enum import Enum
from typing import Any
from pydantic import field_serializer
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.threadpool_concurrency import ThreadSafeDict
@ -82,7 +84,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
# The latest timestamp of a file that has been retrieved per completion key.
# See curr_completion_key for more details on completion keys.
completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch] = ThreadSafeDict()
completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch]
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
@ -91,8 +93,8 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
# cached user emails
user_emails: list[str] | None = None
# @field_serializer("completion_map")
# def serialize_completion_map(
# self, completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch], _info: Any
# ) -> dict[str, SecondsSinceUnixEpoch]:
# return completion_map._dict
@field_serializer("completion_map")
def serialize_completion_map(
self, completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch], _info: Any
) -> dict[str, SecondsSinceUnixEpoch]:
return completion_map._dict
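
The previously commented-out field_serializer is now active, so dumping a checkpoint converts the ThreadSafeDict field back into its plain underlying dict, while the schema change below lets plain dicts be parsed back into a ThreadSafeDict. A self-contained sketch of that round trip (not part of this commit; Checkpoint is a hypothetical stand-in for GoogleDriveCheckpoint, and float stands in for SecondsSinceUnixEpoch):

from typing import Any

from pydantic import BaseModel, field_serializer

from onyx.utils.threadpool_concurrency import ThreadSafeDict

class Checkpoint(BaseModel):
    completion_map: ThreadSafeDict[str, float]

    @field_serializer("completion_map")
    def serialize_completion_map(
        self, completion_map: ThreadSafeDict[str, float], _info: Any
    ) -> dict[str, float]:
        return completion_map._dict

ckpt = Checkpoint(completion_map={"user@example.com": 1700000000.0})
assert isinstance(ckpt.completion_map, ThreadSafeDict)  # validator wraps the plain dict
assert ckpt.model_dump() == {"completion_map": {"user@example.com": 1700000000.0}}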

View File

@ -80,7 +80,15 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return handler(dict[KT, VT])
return core_schema.no_info_after_validator_function(
cls.validate, handler(dict[KT, VT])
)
@classmethod
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
if isinstance(v, dict):
return ThreadSafeDict(v)
return v
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
return ThreadSafeDict(copy.deepcopy(self._dict))
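
The core schema now chains an after-validator onto the plain dict schema produced by handler(dict[KT, VT]), so values parsed for a ThreadSafeDict annotation come back wrapped, while objects that are already a ThreadSafeDict (a MutableMapping, not a dict subclass) pass through validate() unchanged. A small sketch using pydantic's TypeAdapter (not part of this commit):

from pydantic import TypeAdapter

from onyx.utils.threadpool_concurrency import ThreadSafeDict

adapter = TypeAdapter(ThreadSafeDict[str, int])
parsed = adapter.validate_python({"a": 1})
assert isinstance(parsed, ThreadSafeDict)  # the after-validator wrapped the dict

already_wrapped = ThreadSafeDict({"b": 2})
assert ThreadSafeDict.validate(already_wrapped) is already_wrapped  # non-dict passes through
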
@ -325,7 +333,6 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
for future in done:
ind, result = future.result()
print(ind, result)
if result is not None:
yield result
del future_to_index[future]
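
Aside from removing the print(ind, result) debug line, the surrounding loop is the usual submit/wait(FIRST_COMPLETED) pattern: yield whichever generator produces a value next, then resubmit it, without blocking on any single generator. A simplified, self-contained sketch of that pattern (not the actual implementation):

from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from typing import Iterator, TypeVar

R = TypeVar("R")

def interleave(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # map each in-flight "fetch next item" call back to its generator index
        future_to_index = {
            executor.submit(next, gen, None): ind for ind, gen in enumerate(gens)
        }
        while future_to_index:
            done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
            for future in done:
                ind = future_to_index.pop(future)
                result = future.result()
                if result is not None:  # None marks an exhausted generator
                    yield result
                    future_to_index[executor.submit(next, gens[ind], None)] = ind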

View File

@ -389,7 +389,6 @@ def test_parallel_yield_non_blocking() -> None:
results = list(parallel_yield(gens))
print(results)
# Verify no values are missing
assert len(results) == 300 # Should have all values from 0 to 299
assert sorted(results) == list(range(300))