Remove un-needed imports (#999)

Author: Chris Weaver, 2024-01-25 12:10:19 -08:00 (committed by GitHub)
parent 92628357df
commit e94fd8b022
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
3 changed files with 57 additions and 21 deletions
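
The change is the same across all three files: the heavy ML libraries (sentence_transformers, transformers, tensorflow) are no longer imported at module load time. Annotations keep working through a typing.TYPE_CHECKING guard plus string forward references, and each function imports what it needs on first call. A minimal sketch of the pattern, assuming sentence-transformers is installed (the model name and function below are illustrative, not from this diff):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed
    # at runtime, so the heavy import is skipped on module load.
    from sentence_transformers import SentenceTransformer  # type: ignore

_MODEL: Optional["SentenceTransformer"] = None


def get_model(name: str = "all-MiniLM-L6-v2") -> "SentenceTransformer":
    # The real import happens lazily, on first use; the module-level
    # singleton caches the loaded model for later calls.
    from sentence_transformers import SentenceTransformer  # type: ignore

    global _MODEL
    if _MODEL is None:
        _MODEL = SentenceTransformer(name)
    return _MODEL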

View File

@@ -1,4 +1,5 @@
-from sentence_transformers import SentenceTransformer  # type: ignore
+from typing import Optional
+from typing import TYPE_CHECKING
 
 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
@@ -11,12 +12,16 @@ from danswer.search.models import Embedder
 from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.utils.logger import setup_logger
 
+if TYPE_CHECKING:
+    from sentence_transformers import SentenceTransformer  # type: ignore
+
+
 logger = setup_logger()
 
 
 def embed_chunks(
     chunks: list[DocAwareChunk],
-    embedding_model: SentenceTransformer | None = None,
+    embedding_model: Optional["SentenceTransformer"] = None,
     batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
     enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
     passage_prefix: str = ASYM_PASSAGE_PREFIX,
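
Note why the annotation becomes a string: Optional["SentenceTransformer"] stores an unresolved forward reference, so defining embed_chunks no longer requires the class to exist at runtime. A small illustration with a hypothetical function (not from this diff):

from typing import Optional


def load(model: Optional["SentenceTransformer"] = None) -> None:
    # The quoted name is never looked up unless something (e.g.
    # typing.get_type_hints) explicitly tries to resolve it.
    pass


load()  # works even though SentenceTransformer was never imported
print(load.__annotations__["model"])  # e.g. typing.Optional[ForwardRef('SentenceTransformer')]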

View File

@@ -1,13 +1,10 @@
 import logging
 import os
+from typing import Optional
+from typing import TYPE_CHECKING
 
 import numpy as np
 import requests
-import tensorflow as tf  # type: ignore
-from sentence_transformers import CrossEncoder  # type: ignore
-from sentence_transformers import SentenceTransformer  # type: ignore
-from transformers import AutoTokenizer  # type: ignore
-from transformers import TFDistilBertForSequenceClassification  # type: ignore
 
 from danswer.configs.app_configs import CURRENT_PROCESS_IS_AN_INDEXING_JOB
 from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
@@ -33,14 +30,25 @@ logger = setup_logger()
 logging.getLogger("transformers").setLevel(logging.ERROR)
 
-_TOKENIZER: None | AutoTokenizer = None
-_EMBED_MODEL: None | SentenceTransformer = None
-_RERANK_MODELS: None | list[CrossEncoder] = None
-_INTENT_TOKENIZER: None | AutoTokenizer = None
-_INTENT_MODEL: None | TFDistilBertForSequenceClassification = None
+if TYPE_CHECKING:
+    from sentence_transformers import CrossEncoder  # type: ignore
+    from sentence_transformers import SentenceTransformer  # type: ignore
+    from transformers import AutoTokenizer  # type: ignore
+    from transformers import TFDistilBertForSequenceClassification  # type: ignore
+
+_TOKENIZER: Optional["AutoTokenizer"] = None
+_EMBED_MODEL: Optional["SentenceTransformer"] = None
+_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
+_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None
+_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None
 
 
-def get_default_tokenizer() -> AutoTokenizer:
+def get_default_tokenizer() -> "AutoTokenizer":
+    # NOTE: doing a local import here to reduce memory usage caused by
+    # processes importing this file despite not using any of this
+    from transformers import AutoTokenizer  # type: ignore
+
     global _TOKENIZER
     if _TOKENIZER is None:
         _TOKENIZER = AutoTokenizer.from_pretrained(DOCUMENT_ENCODER_MODEL)
@@ -52,7 +60,11 @@ def get_default_tokenizer() -> AutoTokenizer:
 def get_local_embedding_model(
     model_name: str = DOCUMENT_ENCODER_MODEL,
     max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
-) -> SentenceTransformer:
+) -> "SentenceTransformer":
+    # NOTE: doing a local import here to reduce memory usage caused by
+    # processes importing this file despite not using any of this
+    from sentence_transformers import SentenceTransformer  # type: ignore
+
     global _EMBED_MODEL
     if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
         logger.info(f"Loading {model_name}")
@@ -64,7 +76,11 @@ def get_local_embedding_model(
 def get_local_reranking_model_ensemble(
     model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
     max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
-) -> list[CrossEncoder]:
+) -> list["CrossEncoder"]:
+    # NOTE: doing a local import here to reduce memory usage caused by
+    # processes importing this file despite not using any of this
+    from sentence_transformers import CrossEncoder
+
     global _RERANK_MODELS
     if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
         _RERANK_MODELS = []
@@ -76,7 +92,13 @@ def get_local_reranking_model_ensemble(
     return _RERANK_MODELS
 
 
-def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
+def get_intent_model_tokenizer(
+    model_name: str = INTENT_MODEL_VERSION,
+) -> "AutoTokenizer":
+    # NOTE: doing a local import here to reduce memory usage caused by
+    # processes importing this file despite not using any of this
+    from transformers import AutoTokenizer  # type: ignore
+
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
         _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
@@ -86,7 +108,11 @@ def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTo
 def get_local_intent_model(
     model_name: str = INTENT_MODEL_VERSION,
     max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
-) -> TFDistilBertForSequenceClassification:
+) -> "TFDistilBertForSequenceClassification":
+    # NOTE: doing a local import here to reduce memory usage caused by
+    # processes importing this file despite not using any of this
+    from transformers import TFDistilBertForSequenceClassification  # type: ignore
+
     global _INTENT_MODEL
     if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
@@ -137,7 +163,7 @@ class EmbeddingModel:
             model_server_url + "/encoder/bi-encoder-embed" if model_server_url else None
         )
 
-    def load_model(self) -> SentenceTransformer | None:
+    def load_model(self) -> Optional["SentenceTransformer"]:
         if self.embed_server_endpoint:
             return None
@@ -190,7 +216,7 @@ class CrossEncoderEnsembleModel:
             else None
         )
 
-    def load_model(self) -> list[CrossEncoder] | None:
+    def load_model(self) -> list["CrossEncoder"] | None:
         if self.rerank_server_endpoint:
             return None
@@ -242,7 +268,7 @@ class IntentModel:
             model_server_url + "/custom/intent-model" if model_server_url else None
         )
 
-    def load_model(self) -> SentenceTransformer | None:
+    def load_model(self) -> Optional["SentenceTransformer"]:
         if self.intent_server_endpoint:
             return None
@@ -254,6 +280,10 @@
         self,
         query: str,
     ) -> list[float]:
+        # NOTE: doing a local import here to reduce memory usage caused by
+        # processes importing this file despite not using any of this
+        import tensorflow as tf  # type: ignore
+
         if self.intent_server_endpoint:
             intent_request = IntentRequest(query=query)
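
The effect is easy to check from sys.modules: importing danswer.search.search_nlp_models should no longer pull in any of the heavy libraries until one of the getters is called. A sanity-check sketch, assuming a danswer checkout is on PYTHONPATH (the module path comes from the imports in this diff):

import sys

import danswer.search.search_nlp_models  # noqa: F401

# With the lazy-import pattern these should all print False; each
# library appears in sys.modules only after the first call to
# get_default_tokenizer(), get_local_embedding_model(), etc.
for heavy in ("tensorflow", "transformers", "sentence_transformers"):
    print(heavy, "loaded:", heavy in sys.modules)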

View File

@@ -1,5 +1,4 @@
 import numpy as np
-import tensorflow as tf  # type:ignore
 from fastapi import APIRouter
 
 from danswer.search.search_nlp_models import get_intent_model_tokenizer
@@ -13,6 +12,8 @@ router = APIRouter(prefix="/custom")
 @log_function_time(print_only=True)
 def classify_intent(query: str) -> list[float]:
+    import tensorflow as tf  # type:ignore
+
     tokenizer = get_intent_model_tokenizer()
     intent_model = get_local_intent_model()
     model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
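
Same idea on the API-server side: this router module is imported whenever the FastAPI app assembles its routes, so a top-level tensorflow import would load TensorFlow in every such process. Moving it into classify_intent defers the cost to the first intent request, and repeat imports are effectively free because Python caches modules. An illustrative timing sketch (assumes tensorflow is installed; numbers will vary):

import sys
import time


def lazy_call() -> None:
    # After the first call, "import tensorflow" is just a dict lookup
    # in sys.modules, so per-request overhead is negligible.
    import tensorflow as tf  # type: ignore

    _ = tf.constant([1.0])


lazy_call()  # slow: actually loads TensorFlow

start = time.perf_counter()
lazy_call()  # fast: module already cached
print(f"cached import call took {time.perf_counter() - start:.4f}s")
print("tensorflow" in sys.modules)  # True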