mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 13:22:42 +01:00
correct serialization and validation of threadsafedict
This commit is contained in:
parent
cdbcf336a0
commit
06ca775dbb
@ -158,6 +158,8 @@ try:
|
||||
except ValueError:
|
||||
INDEX_BATCH_SIZE = 16
|
||||
|
||||
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
|
||||
|
||||
# Below are intended to match the env variables names used by the official postgres docker image
|
||||
# https://hub.docker.com/_/postgres
|
||||
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
|
||||
|
@ -14,6 +14,7 @@ from googleapiclient.errors import HttpError # type: ignore
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_DRIVE_WORKERS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
@ -435,7 +436,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
)
|
||||
for email in all_org_emails
|
||||
]
|
||||
yield from parallel_yield(user_retrieval_gens, max_workers=10)
|
||||
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
|
||||
|
||||
remaining_folders = (
|
||||
drive_ids_to_retrieve | folder_ids_to_retrieve
|
||||
|
@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import field_serializer
|
||||
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeDict
|
||||
@ -82,7 +84,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
|
||||
# The latest timestamp of a file that has been retrieved per completion key.
|
||||
# See curr_completion_key for more details on completion keys.
|
||||
completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch] = ThreadSafeDict()
|
||||
completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch]
|
||||
|
||||
# cached version of the drive and folder ids to retrieve
|
||||
drive_ids_to_retrieve: list[str] | None = None
|
||||
@ -91,8 +93,8 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
# cached user emails
|
||||
user_emails: list[str] | None = None
|
||||
|
||||
@field_serializer("completion_map")
def serialize_completion_map(
    self, completion_map: ThreadSafeDict[str, SecondsSinceUnixEpoch], _info: Any
) -> dict[str, SecondsSinceUnixEpoch]:
    """Serialize the thread-safe completion map as a plain dict.

    ThreadSafeDict is not directly JSON-serializable, so pydantic is given
    the underlying dict; the checkpoint round-trips back into a
    ThreadSafeDict via the type's pydantic validator on load.

    Args:
        completion_map: per-completion-key latest retrieved-file timestamps.
        _info: pydantic serialization context (unused).

    Returns:
        The raw backing dict of the ThreadSafeDict.
    """
    # Removed the commented-out duplicate of this exact implementation that
    # sat directly above — dead code that only invited divergence.
    return completion_map._dict
|
||||
|
@ -80,7 +80,15 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
|
||||
def __get_pydantic_core_schema__(
    cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
    """Build the pydantic core schema for ThreadSafeDict fields.

    Wraps the plain ``dict[KT, VT]`` schema with ``cls.validate`` so that
    dicts produced during deserialization are coerced back into
    ThreadSafeDict instances rather than left as ordinary dicts.

    Args:
        source_type: the annotated source type (unused; dict schema is used).
        handler: pydantic's schema handler for nested types.

    Returns:
        A core schema that validates to a ThreadSafeDict.
    """
    # Removed the stale `return handler(dict[KT, VT])` that preceded this
    # return — as written it executed first and made the validator-wrapped
    # schema below unreachable, so deserialized values stayed plain dicts.
    return core_schema.no_info_after_validator_function(
        cls.validate, handler(dict[KT, VT])
    )
|
||||
|
||||
@classmethod
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
    """Coerce a plain dict into a ThreadSafeDict; pass other values through.

    Registered as the pydantic after-validator for this type, so values
    deserialized as ordinary dicts come back thread-safe.
    """
    return ThreadSafeDict(v) if isinstance(v, dict) else v
|
||||
|
||||
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
    """Return a new ThreadSafeDict wrapping a deep copy of the backing dict.

    NOTE(review): `memo` is intentionally not forwarded to copy.deepcopy
    here (matching the existing behavior), so shared sub-objects are not
    deduplicated across the enclosing copy — confirm if that matters.
    """
    copied_state = copy.deepcopy(self._dict)
    return ThreadSafeDict(copied_state)
|
||||
@ -325,7 +333,6 @@ def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R
|
||||
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
|
||||
for future in done:
|
||||
ind, result = future.result()
|
||||
print(ind, result)
|
||||
if result is not None:
|
||||
yield result
|
||||
del future_to_index[future]
|
||||
|
@ -389,7 +389,6 @@ def test_parallel_yield_non_blocking() -> None:
|
||||
|
||||
results = list(parallel_yield(gens))
|
||||
|
||||
print(results)
|
||||
# Verify no values are missing
|
||||
assert len(results) == 300 # Should have all values from 0 to 299
|
||||
assert sorted(results) == list(range(300))
|
||||
|
Loading…
x
Reference in New Issue
Block a user