mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-07 11:28:09 +02:00
Memory Reduction (#1092)
This commit is contained in:
parent
d2ce3033a2
commit
927e85319c
@ -8,13 +8,15 @@ from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from torch import multiprocessing
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.multiprocessing import Process
|
||||
|
||||
JobStatusType = (
|
||||
Literal["error"]
|
||||
@ -30,7 +32,7 @@ class SimpleJob:
|
||||
"""Drop in replacement for `dask.distributed.Future`"""
|
||||
|
||||
id: int
|
||||
process: multiprocessing.Process | None = None
|
||||
process: Optional["Process"] = None
|
||||
|
||||
def cancel(self) -> bool:
|
||||
return self.release()
|
||||
@ -87,6 +89,8 @@ class SimpleJobClient:
|
||||
|
||||
def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
|
||||
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
|
||||
from torch.multiprocessing import Process
|
||||
|
||||
self._cleanup_completed_jobs()
|
||||
if len(self.jobs) >= self.n_workers:
|
||||
logger.debug("No available workers to run job")
|
||||
@ -95,7 +99,7 @@ class SimpleJobClient:
|
||||
job_id = self.job_id_counter
|
||||
self.job_id_counter += 1
|
||||
|
||||
process = multiprocessing.Process(target=func, args=args, daemon=True)
|
||||
process = Process(target=func, args=args, daemon=True)
|
||||
job = SimpleJob(id=job_id, process=process)
|
||||
process.start()
|
||||
|
||||
|
@ -4,7 +4,6 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import torch
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
@ -286,6 +285,8 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
import torch
|
||||
|
||||
try:
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
|
@ -3,7 +3,6 @@ import time
|
||||
from datetime import datetime
|
||||
|
||||
import dask
|
||||
import torch
|
||||
from dask.distributed import Client
|
||||
from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
@ -61,6 +60,8 @@ def _get_num_threads() -> int:
|
||||
"""Get # of "threads" to use for ML models in an indexing job. By default uses
|
||||
the torch implementation, which returns the # of physical cores on the machine.
|
||||
"""
|
||||
import torch
|
||||
|
||||
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
|
||||
|
||||
|
||||
@ -457,6 +458,8 @@ def update__main() -> None:
|
||||
# needed for CUDA to work with multiprocessing
|
||||
# NOTE: needs to be done on application startup
|
||||
# before any other torch code has been run
|
||||
import torch
|
||||
|
||||
if not DASK_JOB_CLIENT_ENABLED:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
|
||||
|
@ -1,8 +1,6 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.app_configs import CHUNK_OVERLAP
|
||||
@ -16,10 +14,15 @@ from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
|
||||
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_size: int) -> str:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
blurb_splitter = SentenceSplitter(
|
||||
tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
|
||||
@ -33,11 +36,13 @@ def chunk_large_section(
|
||||
section_link_text: str,
|
||||
document: Document,
|
||||
start_chunk_id: int,
|
||||
tokenizer: AutoTokenizer,
|
||||
tokenizer: "AutoTokenizer",
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
) -> list[DocAwareChunk]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
blurb = extract_blurb(section_text, blurb_size)
|
||||
|
||||
sentence_aware_splitter = SentenceSplitter(
|
||||
@ -155,6 +160,8 @@ def chunk_document(
|
||||
def split_chunk_text_into_mini_chunks(
|
||||
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
|
||||
) -> list[str]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
sentence_aware_splitter = SentenceSplitter(
|
||||
tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
|
||||
|
@ -2,7 +2,7 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import nltk # type:ignore
|
||||
import torch
|
||||
import torch # Import here is fine, API server needs torch anyway and nothing imports main.py
|
||||
import uvicorn
|
||||
from fastapi import APIRouter
|
||||
from fastapi import FastAPI
|
||||
|
@ -1,4 +1,4 @@
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
@ -10,8 +10,11 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
|
||||
def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
|
||||
|
||||
def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
|
||||
"""Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
|
||||
It splits up even foreign characters and unicode emojis without using UNK"""
|
||||
tokenized_text = tokenizer.tokenize(text)
|
||||
|
Loading…
x
Reference in New Issue
Block a user