Memory Reduction (#1092)

commit 927e85319c (parent d2ce3033a2)
Author: Yuhong Sun
Date: 2024-02-17 21:20:34 -08:00
Committed by: GitHub

6 changed files with 31 additions and 13 deletions

View File

@@ -8,13 +8,15 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
 from typing import Literal
-from torch import multiprocessing
+from typing import Optional
+from typing import TYPE_CHECKING

 from danswer.utils.logger import setup_logger

 logger = setup_logger()

+if TYPE_CHECKING:
+    from torch.multiprocessing import Process
+
 JobStatusType = (
     Literal["error"]

@@ -30,7 +32,7 @@ class SimpleJob:
     """Drop in replacement for `dask.distributed.Future`"""

     id: int
-    process: multiprocessing.Process | None = None
+    process: Optional["Process"] = None

     def cancel(self) -> bool:
         return self.release()

@@ -87,6 +89,8 @@ class SimpleJobClient:
     def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
         """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
+        from torch.multiprocessing import Process
+
         self._cleanup_completed_jobs()
         if len(self.jobs) >= self.n_workers:
             logger.debug("No available workers to run job")

@@ -95,7 +99,7 @@ class SimpleJobClient:
         job_id = self.job_id_counter
         self.job_id_counter += 1

-        process = multiprocessing.Process(target=func, args=args, daemon=True)
+        process = Process(target=func, args=args, daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
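
The pattern in this file is the core of the memory win: the TYPE_CHECKING guard plus the quoted Optional["Process"] annotation keep type checkers happy, while torch.multiprocessing is only imported inside submit(), i.e. by processes that actually launch jobs. A minimal runnable sketch of the same pattern, using the stdlib multiprocessing as a stand-in for torch.multiprocessing (class and function names here are illustrative, not from the commit):

from typing import Optional, TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers (e.g. mypy); never executed at runtime.
    from multiprocessing import Process


class LazyJob:
    # Quoted annotation, so Process need not be importable when the class is defined.
    process: Optional["Process"] = None

    def submit(self, func) -> None:
        # Deferred import: paid on first call, then served from sys.modules.
        from multiprocessing import Process

        self.process = Process(target=func, daemon=True)
        self.process.start()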

View File

@@ -4,7 +4,6 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone

-import torch
 from sqlalchemy.orm import Session

 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt

@@ -286,6 +285,8 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
    Wraps the actual logic in a `try` block so that we can catch any exceptions
    and mark the attempt as failed."""
+    import torch
+
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
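
Deferring import torch into run_indexing_entrypoint means this module can be imported (e.g. by the scheduler process) without ever loading torch; only the child process that runs the attempt pays the cost. A quick way to verify this kind of deferral, sketched with sys.modules (illustrative; torch only needs to be installed if entrypoint() is actually called):

import sys


def entrypoint() -> None:
    import torch  # loaded only inside the process that runs the job

    torch.set_num_threads(4)


# At module-import time the heavy dependency is still absent.
print("torch" in sys.modules)  # False, as long as nothing else imported it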

View File

@@ -3,7 +3,6 @@ import time
 from datetime import datetime

 import dask
-import torch
 from dask.distributed import Client
 from dask.distributed import Future
 from distributed import LocalCluster

@@ -61,6 +60,8 @@ def _get_num_threads() -> int:
     """Get # of "threads" to use for ML models in an indexing job. By default uses
    the torch implementation, which returns the # of physical cores on the machine.
    """
+    import torch
+
     return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())

@@ -457,6 +458,8 @@ def update__main() -> None:
     # needed for CUDA to work with multiprocessing
     # NOTE: needs to be done on application startup
     # before any other torch code has been run
+    import torch
+
     if not DASK_JOB_CLIENT_ENABLED:
         torch.multiprocessing.set_start_method("spawn")
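
The comment in the last hunk is the key ordering constraint: set_start_method("spawn") has to run once at startup, before any other torch or CUDA code, because fork-started children inherit CUDA state that is not fork-safe, while spawn-started children get a fresh interpreter. A minimal sketch of that startup ordering (assumes torch is installed; the worker function is illustrative):

import torch


def worker() -> None:
    # Under "spawn" this runs in a fresh interpreter, so CUDA can initialize cleanly.
    print(torch.cuda.is_available())


if __name__ == "__main__":
    # Must happen before any CUDA call and before any worker process is created.
    torch.multiprocessing.set_start_method("spawn")
    p = torch.multiprocessing.Process(target=worker, daemon=True)
    p.start()
    p.join()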

View File

@@ -1,8 +1,6 @@
 import abc
 from collections.abc import Callable
+from typing import TYPE_CHECKING

-from llama_index.text_splitter import SentenceSplitter
-from transformers import AutoTokenizer  # type:ignore
 from danswer.configs.app_configs import BLURB_SIZE
 from danswer.configs.app_configs import CHUNK_OVERLAP

@@ -16,10 +14,15 @@ from danswer.search.search_nlp_models import get_default_tokenizer
 from danswer.utils.text_processing import shared_precompare_cleanup

+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
+
 ChunkFunc = Callable[[Document], list[DocAwareChunk]]


 def extract_blurb(text: str, blurb_size: int) -> str:
+    from llama_index.text_splitter import SentenceSplitter
+
     token_count_func = get_default_tokenizer().tokenize
     blurb_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0

@@ -33,11 +36,13 @@ def chunk_large_section(
     section_link_text: str,
     document: Document,
     start_chunk_id: int,
-    tokenizer: AutoTokenizer,
+    tokenizer: "AutoTokenizer",
     chunk_size: int = CHUNK_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
+    from llama_index.text_splitter import SentenceSplitter
+
     blurb = extract_blurb(section_text, blurb_size)

     sentence_aware_splitter = SentenceSplitter(

@@ -155,6 +160,8 @@ def chunk_document(
 def split_chunk_text_into_mini_chunks(
     chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
 ) -> list[str]:
+    from llama_index.text_splitter import SentenceSplitter
+
     token_count_func = get_default_tokenizer().tokenize
     sentence_aware_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
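
Pushing SentenceSplitter into each function body trades a one-time module-load cost (paid only by processes that actually chunk documents) for a per-call lookup in sys.modules, which is effectively free. A stand-alone illustration of that caching behavior, using json as a stand-in for the heavy dependency:

import sys
import timeit


def lazy_chunk() -> None:
    # First call triggers the real import; later calls are a dict lookup.
    import json  # stand-in for llama_index.text_splitter.SentenceSplitter

    json.dumps({"chunk": "text"})


lazy_chunk()  # pays the import once
assert "json" in sys.modules  # cached from here on
print(timeit.timeit(lazy_chunk, number=100_000))  # cheap per-call overhead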

View File

@@ -2,7 +2,7 @@ from typing import Any
 from typing import cast

 import nltk  # type:ignore
-import torch
+import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
 import uvicorn
 from fastapi import APIRouter
 from fastapi import FastAPI

View File

@@ -1,4 +1,4 @@
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING

 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType

@@ -10,8 +10,11 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()

+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
+

-def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
+def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
     """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
    It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
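
With the import behind TYPE_CHECKING and the annotation quoted, transformers is loaded only by callers that actually construct a tokenizer. A hedged usage sketch of the counting this function performs (assumes transformers is installed; the model name is illustrative, not the one Danswer uses):

from transformers import AutoTokenizer  # type:ignore

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
tokens = tokenizer.tokenize("hello 🤖 world")
# Mirrors count_unk_tokens: count occurrences of the tokenizer's UNK token.
print(tokens.count(tokenizer.unk_token))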