mirror of https://github.com/danswer-ai/danswer.git
Memory Reduction (#1092)
@@ -8,13 +8,15 @@ from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Any
 from typing import Literal
-
-from torch import multiprocessing
+from typing import Optional
+from typing import TYPE_CHECKING
 
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
 
+if TYPE_CHECKING:
+    from torch.multiprocessing import Process
+
 
 JobStatusType = (
     Literal["error"]
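
This hunk shows the pattern the whole commit applies: heavy libraries are no longer imported at module load time. Type-only references move behind `typing.TYPE_CHECKING` with string annotations, and runtime uses are imported inside the function that actually needs them, so a process that never calls that function never loads the library. A minimal, self-contained sketch of the idiom, using the standard-library `multiprocessing.Process` as a stand-in for `torch.multiprocessing.Process`:

    from dataclasses import dataclass
    from typing import Any, Callable, Optional, TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by type checkers; never executed, so the "heavy" module is
        # not loaded in processes that merely import this file.
        from multiprocessing import Process  # stand-in for torch.multiprocessing.Process


    @dataclass
    class JobHandle:
        id: int
        # String annotation: a forward reference the type checker resolves,
        # so Process does not have to be importable at runtime.
        process: Optional["Process"] = None


    def submit(func: Callable[..., Any], *args: Any) -> JobHandle:
        # Deferred import: paid only when a job is actually launched, and only
        # once per process thanks to sys.modules caching.
        from multiprocessing import Process

        process = Process(target=func, args=args, daemon=True)
        process.start()
        return JobHandle(id=0, process=process)

The same idea covers `torch`, `transformers.AutoTokenizer`, and llama_index's `SentenceSplitter` in the remaining hunks.
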
@@ -30,7 +32,7 @@ class SimpleJob:
     """Drop in replacement for `dask.distributed.Future`"""
 
     id: int
-    process: multiprocessing.Process | None = None
+    process: Optional["Process"] = None
 
     def cancel(self) -> bool:
         return self.release()
@@ -87,6 +89,8 @@ class SimpleJobClient:
 
     def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
         """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
+        from torch.multiprocessing import Process
+
         self._cleanup_completed_jobs()
         if len(self.jobs) >= self.n_workers:
             logger.debug("No available workers to run job")
@@ -95,7 +99,7 @@ class SimpleJobClient:
         job_id = self.job_id_counter
         self.job_id_counter += 1
 
-        process = multiprocessing.Process(target=func, args=args, daemon=True)
+        process = Process(target=func, args=args, daemon=True)
         job = SimpleJob(id=job_id, process=process)
         process.start()
 
@@ -4,7 +4,6 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
 
-import torch
 from sqlalchemy.orm import Session
 
 from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -286,6 +285,8 @@ def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch any exceptions
     and mark the attempt as failed."""
+    import torch
+
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
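
Deferring `import torch` into the entrypoint matters because this function is what the spawned worker process runs: the parent that schedules the attempt no longer needs torch in memory at all. A rough, self-contained sketch of that shape; the `_run_indexing` and `_mark_failed` helpers are placeholders rather than names from the diff, and the `torch.set_num_threads` call only illustrates why the worker needs torch:

    def _run_indexing(attempt_id: int) -> None:
        ...  # placeholder for the actual indexing work


    def _mark_failed(attempt_id: int) -> None:
        ...  # placeholder: record the failure so the attempt is not left "running"


    def run_entrypoint(attempt_id: int, num_threads: int) -> None:
        """Runs inside the spawned worker; heavy imports stay out of the parent."""
        import torch  # loaded only in the worker process that needs it

        torch.set_num_threads(num_threads)  # per-worker thread budget (illustrative)
        try:
            _run_indexing(attempt_id)
        except Exception:
            _mark_failed(attempt_id)
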
@@ -3,7 +3,6 @@ import time
 from datetime import datetime
 
 import dask
-import torch
 from dask.distributed import Client
 from dask.distributed import Future
 from distributed import LocalCluster
@@ -61,6 +60,8 @@ def _get_num_threads() -> int:
     """Get # of "threads" to use for ML models in an indexing job. By default uses
     the torch implementation, which returns the # of physical cores on the machine.
     """
+    import torch
+
     return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
 
 
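
The docstring carries the reasoning: torch already sizes its intra-op thread pool from the machine's physical cores, so the function just reuses that number and enforces a floor. A small sketch with an assumed value for `MIN_THREADS_ML_MODELS` (the real constant lives in danswer's config):

    MIN_THREADS_ML_MODELS = 4  # assumed value for illustration


    def get_num_threads() -> int:
        import torch  # deferred, matching the commit: loaded only where it is used

        # torch.get_num_threads() reflects torch's own core detection; the max()
        # keeps small machines from dropping below a useful minimum for ML models.
        return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
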
@@ -457,6 +458,8 @@ def update__main() -> None:
     # needed for CUDA to work with multiprocessing
     # NOTE: needs to be done on application startup
     # before any other torch code has been run
+    import torch
+
     if not DASK_JOB_CLIENT_ENABLED:
         torch.multiprocessing.set_start_method("spawn")
 
@@ -1,8 +1,6 @@
 import abc
 from collections.abc import Callable
-
-from llama_index.text_splitter import SentenceSplitter
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING
 
 from danswer.configs.app_configs import BLURB_SIZE
 from danswer.configs.app_configs import CHUNK_OVERLAP
@@ -16,10 +14,15 @@ from danswer.search.search_nlp_models import get_default_tokenizer
 from danswer.utils.text_processing import shared_precompare_cleanup
 
 
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
+
 ChunkFunc = Callable[[Document], list[DocAwareChunk]]
 
 
 def extract_blurb(text: str, blurb_size: int) -> str:
+    from llama_index.text_splitter import SentenceSplitter
+
     token_count_func = get_default_tokenizer().tokenize
     blurb_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
@@ -33,11 +36,13 @@ def chunk_large_section(
     section_link_text: str,
     document: Document,
     start_chunk_id: int,
-    tokenizer: AutoTokenizer,
+    tokenizer: "AutoTokenizer",
     chunk_size: int = CHUNK_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
+    from llama_index.text_splitter import SentenceSplitter
+
     blurb = extract_blurb(section_text, blurb_size)
 
     sentence_aware_splitter = SentenceSplitter(
@@ -155,6 +160,8 @@ def chunk_document(
 def split_chunk_text_into_mini_chunks(
     chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
 ) -> list[str]:
+    from llama_index.text_splitter import SentenceSplitter
+
     token_count_func = get_default_tokenizer().tokenize
     sentence_aware_splitter = SentenceSplitter(
         tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
@@ -2,7 +2,7 @@ from typing import Any
 from typing import cast
 
 import nltk  # type:ignore
-import torch
+import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
 import uvicorn
 from fastapi import APIRouter
 from fastapi import FastAPI
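
The new comment documents the one place a module-level `import torch` is kept: the API server process needs torch regardless, and nothing imports `main.py`, so no other process inherits the cost. That cost is easy to observe directly; a rough measurement sketch using `psutil`, which is an assumption here rather than a dependency shown in this diff:

    import psutil  # assumption: used only to read this process's memory


    def rss_mib() -> float:
        return psutil.Process().memory_info().rss / 1024**2


    before = rss_mib()
    import torch  # noqa: E402  deliberately late, to isolate its footprint
    after = rss_mib()

    print(f"RSS before importing torch: {before:.0f} MiB, after: {after:.0f} MiB")
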
@@ -1,4 +1,4 @@
-from transformers import AutoTokenizer  # type:ignore
+from typing import TYPE_CHECKING
 
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
@@ -10,8 +10,11 @@ from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
 
+if TYPE_CHECKING:
+    from transformers import AutoTokenizer  # type:ignore
 
-def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
+
+def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
     """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
     It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
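
`count_unk_tokens` only needs the tokenizer at call time, which is why the annotation can become a string while the real import waits behind `TYPE_CHECKING`. A small usage sketch with a Hugging Face tokenizer; the counting line and the `bert-base-uncased` checkpoint are illustrative guesses, since the hunk only shows the first line of the function body:

    from transformers import AutoTokenizer  # type:ignore


    def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
        tokenized_text = tokenizer.tokenize(text)
        # unk_token is the literal string the tokenizer emits for unknowns, e.g. "[UNK]"
        return sum(1 for token in tokenized_text if token == tokenizer.unk_token)


    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example checkpoint
    print(count_unk_tokens("hello 🤖 world", tokenizer))
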