mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-03 09:28:25 +02:00
Remove un-needed imports (#999)
This commit is contained in:
parent
92628357df
commit
e94fd8b022
@ -1,4 +1,5 @@
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
|
||||
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
@ -11,12 +12,16 @@ from danswer.search.models import Embedder
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def embed_chunks(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedding_model: SentenceTransformer | None = None,
|
||||
embedding_model: Optional["SentenceTransformer"] = None,
|
||||
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
|
||||
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
|
||||
passage_prefix: str = ASYM_PASSAGE_PREFIX,
|
||||
|
@ -1,13 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf # type: ignore
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import TFDistilBertForSequenceClassification # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import CURRENT_PROCESS_IS_AN_INDEXING_JOB
|
||||
from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
|
||||
@ -33,14 +30,25 @@ logger = setup_logger()
|
||||
logging.getLogger("transformers").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
_TOKENIZER: None | AutoTokenizer = None
|
||||
_EMBED_MODEL: None | SentenceTransformer = None
|
||||
_RERANK_MODELS: None | list[CrossEncoder] = None
|
||||
_INTENT_TOKENIZER: None | AutoTokenizer = None
|
||||
_INTENT_MODEL: None | TFDistilBertForSequenceClassification = None
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import TFDistilBertForSequenceClassification # type: ignore
|
||||
|
||||
|
||||
def get_default_tokenizer() -> AutoTokenizer:
|
||||
_TOKENIZER: Optional["AutoTokenizer"] = None
|
||||
_EMBED_MODEL: Optional["SentenceTransformer"] = None
|
||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
||||
_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None
|
||||
_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None
|
||||
|
||||
|
||||
def get_default_tokenizer() -> "AutoTokenizer":
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
global _TOKENIZER
|
||||
if _TOKENIZER is None:
|
||||
_TOKENIZER = AutoTokenizer.from_pretrained(DOCUMENT_ENCODER_MODEL)
|
||||
@ -52,7 +60,11 @@ def get_default_tokenizer() -> AutoTokenizer:
|
||||
def get_local_embedding_model(
|
||||
model_name: str = DOCUMENT_ENCODER_MODEL,
|
||||
max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
) -> SentenceTransformer:
|
||||
) -> "SentenceTransformer":
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
global _EMBED_MODEL
|
||||
if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
|
||||
logger.info(f"Loading {model_name}")
|
||||
@ -64,7 +76,11 @@ def get_local_embedding_model(
|
||||
def get_local_reranking_model_ensemble(
|
||||
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
|
||||
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
|
||||
) -> list[CrossEncoder]:
|
||||
) -> list["CrossEncoder"]:
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
global _RERANK_MODELS
|
||||
if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
|
||||
_RERANK_MODELS = []
|
||||
@ -76,7 +92,13 @@ def get_local_reranking_model_ensemble(
|
||||
return _RERANK_MODELS
|
||||
|
||||
|
||||
def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
|
||||
def get_intent_model_tokenizer(
|
||||
model_name: str = INTENT_MODEL_VERSION,
|
||||
) -> "AutoTokenizer":
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
|
||||
@ -86,7 +108,11 @@ def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTo
|
||||
def get_local_intent_model(
|
||||
model_name: str = INTENT_MODEL_VERSION,
|
||||
max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
|
||||
) -> TFDistilBertForSequenceClassification:
|
||||
) -> "TFDistilBertForSequenceClassification":
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
from transformers import TFDistilBertForSequenceClassification # type: ignore
|
||||
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
|
||||
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
|
||||
@ -137,7 +163,7 @@ class EmbeddingModel:
|
||||
model_server_url + "/encoder/bi-encoder-embed" if model_server_url else None
|
||||
)
|
||||
|
||||
def load_model(self) -> SentenceTransformer | None:
|
||||
def load_model(self) -> Optional["SentenceTransformer"]:
|
||||
if self.embed_server_endpoint:
|
||||
return None
|
||||
|
||||
@ -190,7 +216,7 @@ class CrossEncoderEnsembleModel:
|
||||
else None
|
||||
)
|
||||
|
||||
def load_model(self) -> list[CrossEncoder] | None:
|
||||
def load_model(self) -> list["CrossEncoder"] | None:
|
||||
if self.rerank_server_endpoint:
|
||||
return None
|
||||
|
||||
@ -242,7 +268,7 @@ class IntentModel:
|
||||
model_server_url + "/custom/intent-model" if model_server_url else None
|
||||
)
|
||||
|
||||
def load_model(self) -> SentenceTransformer | None:
|
||||
def load_model(self) -> Optional["SentenceTransformer"]:
|
||||
if self.intent_server_endpoint:
|
||||
return None
|
||||
|
||||
@ -254,6 +280,10 @@ class IntentModel:
|
||||
self,
|
||||
query: str,
|
||||
) -> list[float]:
|
||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
||||
# processes importing this file despite not using any of this
|
||||
import tensorflow as tf # type: ignore
|
||||
|
||||
if self.intent_server_endpoint:
|
||||
intent_request = IntentRequest(query=query)
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf # type:ignore
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.search.search_nlp_models import get_intent_model_tokenizer
|
||||
@ -13,6 +12,8 @@ router = APIRouter(prefix="/custom")
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def classify_intent(query: str) -> list[float]:
|
||||
import tensorflow as tf # type:ignore
|
||||
|
||||
tokenizer = get_intent_model_tokenizer()
|
||||
intent_model = get_local_intent_model()
|
||||
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user