danswer/backend/model_server/custom_models.py
2024-04-10 23:13:22 -07:00

76 lines
2.5 KiB
Python

from typing import Optional
import numpy as np
import tensorflow as tf # type: ignore
from fastapi import APIRouter
from transformers import AutoTokenizer # type: ignore
from transformers import TFDistilBertForSequenceClassification
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INTENT_MODEL_CONTEXT_SIZE
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
router = APIRouter(prefix="/custom")
_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
def get_intent_model_tokenizer(
model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
return _INTENT_TOKENIZER
def get_local_intent_model(
model_name: str = INTENT_MODEL_VERSION,
max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
) -> TFDistilBertForSequenceClassification:
global _INTENT_MODEL
if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
model_name
)
_INTENT_MODEL.max_seq_length = max_context_length
return _INTENT_MODEL
def warm_up_intent_model() -> None:
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)
@simple_log_function_time()
def classify_intent(query: str) -> list[float]:
tokenizer = get_intent_model_tokenizer()
intent_model = get_local_intent_model()
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
predictions = intent_model(model_input)[0]
probabilities = tf.nn.softmax(predictions, axis=-1)
class_percentages = np.round(probabilities.numpy() * 100, 2)
return list(class_percentages.tolist()[0])
@router.post("/intent-model")
async def process_intent_request(
intent_request: IntentRequest,
) -> IntentResponse:
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
class_percentages = classify_intent(intent_request.query)
return IntentResponse(class_probs=class_percentages)