Memory Reduction (#1092)

Yuhong Sun, 2024-02-17 21:20:34 -08:00 (committed by GitHub)
parent d2ce3033a2
commit 927e85319c
6 changed files with 31 additions and 13 deletions


@@ -8,13 +8,15 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
 from typing import Literal
-from torch import multiprocessing
+from typing import Optional
+from typing import TYPE_CHECKING
 from danswer.utils.logger import setup_logger
 logger = setup_logger()
+if TYPE_CHECKING:
+    from torch.multiprocessing import Process
 JobStatusType = (
     Literal["error"]
@@ -30,7 +32,7 @@ class SimpleJob:
     """Drop in replacement for `dask.distributed.Future`"""
     id: int
-    process: multiprocessing.Process | None = None
+    process: Optional["Process"] = None
     def cancel(self) -> bool:
         return self.release()
@@ -87,6 +89,8 @@ class SimpleJobClient:
    def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
        """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
+        from torch.multiprocessing import Process
         self._cleanup_completed_jobs()
         if len(self.jobs) >= self.n_workers:
             logger.debug("No available workers to run job")
@@ -95,7 +99,7 @@ class SimpleJobClient:
         job_id = self.job_id_counter
         self.job_id_counter += 1
-        process = multiprocessing.Process(target=func, args=args, daemon=True)
+        process = Process(target=func, args=args, daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
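The pattern this file applies, reduced to a standalone sketch (the class and method are simplified stand-ins, not the exact repo code): `Process` lives behind `TYPE_CHECKING` so the annotation resolves without importing torch, and `torch.multiprocessing` is only imported inside `submit()` once a job is actually spawned.

```python
# Minimal sketch of the deferred-import pattern above; SimpleJob/submit are
# simplified stand-ins for the real class and method.
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so torch stays unloaded.
    from torch.multiprocessing import Process


@dataclass
class SimpleJob:
    id: int
    process: Optional["Process"] = None  # string annotation avoids a runtime torch import


def submit(job_id: int, func: Callable[..., Any], *args: Any) -> SimpleJob:
    # Deferred import: only the code path that actually spawns a worker
    # pays the cost of loading torch.
    from torch.multiprocessing import Process

    process = Process(target=func, args=args, daemon=True)
    process.start()
    return SimpleJob(id=job_id, process=process)
```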


@@ -4,7 +4,6 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
-import torch
 from sqlalchemy.orm import Session
 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -286,6 +285,8 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch any exceptions
     and mark the attempt as failed."""
+    import torch
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
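A hedged sketch of what this hunk changes: `import torch` moves from module scope into the entrypoint body, so the scheduler process that merely imports this module never loads torch; only the spawned indexing process does. The `torch.set_num_threads` call below is an assumption about how `num_threads` gets applied, since that part is not shown in the hunk.

```python
# Sketch only: the real indexing logic is elided, and applying num_threads
# via torch.set_num_threads() is an illustrative assumption.
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
    import torch  # deferred: loaded inside the spawned worker, not the scheduler

    try:
        torch.set_num_threads(num_threads)  # assumed use of the num_threads argument
        # ... run the actual indexing attempt for index_attempt_id here ...
    except Exception:
        # mark the attempt as failed before re-raising (placeholder)
        raise
```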


@@ -3,7 +3,6 @@ import time
 from datetime import datetime
 import dask
-import torch
 from dask.distributed import Client
 from dask.distributed import Future
 from distributed import LocalCluster
@@ -61,6 +60,8 @@ def _get_num_threads() -> int:
     """Get # of "threads" to use for ML models in an indexing job. By default uses
     the torch implementation, which returns the # of physical cores on the machine.
     """
+    import torch
     return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
@@ -457,6 +458,8 @@ def update__main() -> None:
     # needed for CUDA to work with multiprocessing
     # NOTE: needs to be done on application startup
     # before any other torch code has been run
+    import torch
     if not DASK_JOB_CLIENT_ENABLED:
         torch.multiprocessing.set_start_method("spawn")
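The two hunks above combined into a runnable sketch: torch is imported lazily where the thread count is computed, and the "spawn" start method is set once at startup, before any other torch/CUDA code runs in the process. The constant's value and the bare `__main__` guard (the real code checks `DASK_JOB_CLIENT_ENABLED`) are assumptions for illustration.

```python
MIN_THREADS_ML_MODELS = 2  # assumed value, for illustration only


def _get_num_threads() -> int:
    import torch  # deferred so importing this module stays lightweight

    # torch.get_num_threads() defaults to the machine's physical core count
    return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())


if __name__ == "__main__":
    import torch

    # Must run at application startup, before any other torch/CUDA work:
    # "spawn" avoids inheriting a forked, already-initialized CUDA context.
    torch.multiprocessing.set_start_method("spawn")
    print(f"indexing threads: {_get_num_threads()}")
```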


@@ -1,8 +1,6 @@
 import abc
 from collections.abc import Callable
-from llama_index.text_splitter import SentenceSplitter
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING
 from danswer.configs.app_configs import BLURB_SIZE
 from danswer.configs.app_configs import CHUNK_OVERLAP
@@ -16,10 +14,15 @@ from danswer.search.search_nlp_models import get_default_tokenizer
 from danswer.utils.text_processing import shared_precompare_cleanup
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
 ChunkFunc = Callable[[Document], list[DocAwareChunk]]
 def extract_blurb(text: str, blurb_size: int) -> str:
+    from llama_index.text_splitter import SentenceSplitter
     token_count_func = get_default_tokenizer().tokenize
     blurb_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
@@ -33,11 +36,13 @@ def chunk_large_section(
     section_link_text: str,
     document: Document,
     start_chunk_id: int,
-    tokenizer: AutoTokenizer,
+    tokenizer: "AutoTokenizer",
     chunk_size: int = CHUNK_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
+    from llama_index.text_splitter import SentenceSplitter
     blurb = extract_blurb(section_text, blurb_size)
     sentence_aware_splitter = SentenceSplitter(
@@ -155,6 +160,8 @@ def chunk_document(
 def split_chunk_text_into_mini_chunks(
     chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
 ) -> list[str]:
+    from llama_index.text_splitter import SentenceSplitter
     token_count_func = get_default_tokenizer().tokenize
     sentence_aware_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
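The same idea applied to the chunker, as a self-contained sketch: `AutoTokenizer` survives only as a type-checking-time name, and `SentenceSplitter` is imported inside the function that needs it, so importing the chunking module no longer drags in `transformers` or `llama_index`. The function name and chunk size here are illustrative, not the repo's.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import AutoTokenizer  # type: ignore


def split_text(text: str, tokenizer: "AutoTokenizer", chunk_size: int = 512) -> list[str]:
    # Deferred import: only callers that actually chunk text load llama_index.
    from llama_index.text_splitter import SentenceSplitter

    splitter = SentenceSplitter(
        tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=0
    )
    return splitter.split_text(text)
```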


@@ -2,7 +2,7 @@ from typing import Any
 from typing import cast
 import nltk  # type:ignore
-import torch
+import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
 import uvicorn
 from fastapi import APIRouter
 from fastapi import FastAPI
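For contrast, the API server entrypoint keeps its eager import, as the new comment explains: nothing imports `main.py`, and the server needs torch regardless, so deferring it would save nothing. A minimal sketch of that shape (the route, host, and port are illustrative, not the repo's wiring):

```python
import torch  # eager import is fine here: the API server needs torch anyway
import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health() -> dict[str, str]:
    return {"status": "ok", "torch_threads": str(torch.get_num_threads())}


if __name__ == "__main__":
    # Because no other module imports this file, the top-level torch import
    # never leaks into the lighter background processes.
    uvicorn.run(app, host="0.0.0.0", port=8080)
```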


@@ -1,4 +1,4 @@
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
@@ -10,8 +10,11 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
-def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
+def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
     """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
     It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
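And the same treatment for this helper, shown as a complete sketch: the `transformers` import exists only for the type checker, while at runtime the function just needs an object exposing `tokenize()` and an `unk_token` attribute. The counting logic after the tokenize call is an assumption based on the docstring; the rest of the function is not shown in the hunk.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from transformers import AutoTokenizer  # type: ignore


def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
    # Any tokenizer with .tokenize() and .unk_token works at runtime;
    # transformers is never imported unless the caller already loaded it.
    tokenized_text = tokenizer.tokenize(text)
    return sum(1 for token in tokenized_text if token == tokenizer.unk_token)
```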