diff --git a/backend/danswer/background/indexing/job_client.py b/backend/danswer/background/indexing/job_client.py
index 8e22f8e45..d37690627 100644
--- a/backend/danswer/background/indexing/job_client.py
+++ b/backend/danswer/background/indexing/job_client.py
@@ -8,13 +8,15 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
 from typing import Literal
-
-from torch import multiprocessing
+from typing import Optional
+from typing import TYPE_CHECKING

 from danswer.utils.logger import setup_logger

 logger = setup_logger()

+if TYPE_CHECKING:
+    from torch.multiprocessing import Process

 JobStatusType = (
     Literal["error"]
@@ -30,7 +32,7 @@ class SimpleJob:
     """Drop in replacement for `dask.distributed.Future`"""

     id: int
-    process: multiprocessing.Process | None = None
+    process: Optional["Process"] = None

     def cancel(self) -> bool:
         return self.release()
@@ -87,6 +89,8 @@ class SimpleJobClient:

     def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
         """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
+        from torch.multiprocessing import Process
+
         self._cleanup_completed_jobs()
         if len(self.jobs) >= self.n_workers:
             logger.debug("No available workers to run job")
@@ -95,7 +99,7 @@
         job_id = self.job_id_counter
         self.job_id_counter += 1

-        process = multiprocessing.Process(target=func, args=args, daemon=True)
+        process = Process(target=func, args=args, daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py
index 79c5f7903..2399facf1 100644
--- a/backend/danswer/background/indexing/run_indexing.py
+++ b/backend/danswer/background/indexing/run_indexing.py
@@ -4,7 +4,6 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone

-import torch
 from sqlalchemy.orm import Session

 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -286,6 +285,8 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch any exceptions
     and mark the attempt as failed."""
+    import torch
+
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index 851ada5d0..7c28e8d4c 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -3,7 +3,6 @@ import time
 from datetime import datetime

 import dask
-import torch
 from dask.distributed import Client
 from dask.distributed import Future
 from distributed import LocalCluster
@@ -61,6 +60,8 @@ def _get_num_threads() -> int:
     """Get # of "threads" to use for ML models in an indexing job. By default uses
     the torch implementation, which returns the # of physical cores on the machine.
     """
+    import torch
+
     return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
@@ -457,6 +458,8 @@ def update__main() -> None:
     # needed for CUDA to work with multiprocessing
     # NOTE: needs to be done on application startup
     # before any other torch code has been run
+    import torch
+
     if not DASK_JOB_CLIENT_ENABLED:
         torch.multiprocessing.set_start_method("spawn")
diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py
index 1166c93ff..c30162327 100644
--- a/backend/danswer/indexing/chunker.py
+++ b/backend/danswer/indexing/chunker.py
@@ -1,8 +1,6 @@
 import abc
 from collections.abc import Callable
-
-from llama_index.text_splitter import SentenceSplitter
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING

 from danswer.configs.app_configs import BLURB_SIZE
 from danswer.configs.app_configs import CHUNK_OVERLAP
@@ -16,10 +14,15 @@ from danswer.search.search_nlp_models import get_default_tokenizer
 from danswer.utils.text_processing import shared_precompare_cleanup

+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
+

 ChunkFunc = Callable[[Document], list[DocAwareChunk]]


 def extract_blurb(text: str, blurb_size: int) -> str:
+    from llama_index.text_splitter import SentenceSplitter
+
     token_count_func = get_default_tokenizer().tokenize
     blurb_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
     )
@@ -33,11 +36,13 @@
     section_link_text: str,
     document: Document,
     start_chunk_id: int,
-    tokenizer: AutoTokenizer,
+    tokenizer: "AutoTokenizer",
     chunk_size: int = CHUNK_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
+    from llama_index.text_splitter import SentenceSplitter
+
     blurb = extract_blurb(section_text, blurb_size)

     sentence_aware_splitter = SentenceSplitter(
@@ -155,6 +160,8 @@ def chunk_document(
 def split_chunk_text_into_mini_chunks(
     chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
 ) -> list[str]:
+    from llama_index.text_splitter import SentenceSplitter
+
     token_count_func = get_default_tokenizer().tokenize
     sentence_aware_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
     )
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 7f5b52ed0..5f8bdd85b 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -2,7 +2,7 @@ from typing import Any
 from typing import cast

 import nltk  # type:ignore
-import torch
+import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
 import uvicorn
 from fastapi import APIRouter
 from fastapi import FastAPI
diff --git a/backend/danswer/search/danswer_helper.py b/backend/danswer/search/danswer_helper.py
index e3de6f923..d5dbeb8a3 100644
--- a/backend/danswer/search/danswer_helper.py
+++ b/backend/danswer/search/danswer_helper.py
@@ -1,4 +1,4 @@
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING

 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
@@ -10,8 +10,11 @@ from danswer.utils.logger import setup_logger

 logger = setup_logger()

+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore

-def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
+
+def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
     """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
     It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
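
Every file touched above applies the same deferred-import idiom in two halves: heavy imports used only in annotations move behind `if TYPE_CHECKING:` (with the annotations quoted, so they stay unevaluated forward references at runtime), while imports used for real work move inside the functions that need them, so only processes that actually call those functions pay the cost of loading torch, transformers, or llama_index. A minimal sketch of the idiom, assuming torch is installed; the `start_worker` helper below is illustrative only, not part of this patch:

    from collections.abc import Callable
    from typing import Optional
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers such as mypy; never executed at runtime.
        from torch.multiprocessing import Process


    def start_worker(target: Callable, *args: object) -> Optional["Process"]:
        # Deferred to first call: a process that never spawns a worker never
        # imports torch at all, and so never pays its import cost.
        from torch.multiprocessing import Process

        process = Process(target=target, args=args, daemon=True)
        process.start()
        return process

Because `"Process"` is a string annotation, importing this module never evaluates it, so the module loads cleanly even though torch is absent at import time; mypy still resolves the name through the `TYPE_CHECKING` block.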