from typing import Optional

import numpy as np
import tensorflow as tf  # type: ignore
from fastapi import APIRouter
from transformers import AutoTokenizer  # type: ignore
from transformers import TFDistilBertForSequenceClassification

from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INTENT_MODEL_CONTEXT_SIZE
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse

router = APIRouter(prefix="/custom")

# Lazily-initialized module-level singletons so the tokenizer and model are
# loaded once per process and reused across requests.
_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None


def get_intent_model_tokenizer(
    model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
    """Return the cached intent tokenizer, loading it on first use."""
    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    return _INTENT_TOKENIZER


def get_local_intent_model(
    model_name: str = INTENT_MODEL_VERSION,
    max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
) -> TFDistilBertForSequenceClassification:
    """Return the cached intent model, (re)loading it when absent or when the
    requested context length differs from the cached instance's."""
    global _INTENT_MODEL
    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
            model_name
        )
        _INTENT_MODEL.max_seq_length = max_context_length
    return _INTENT_MODEL
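
# Illustrative sketch (not part of the original module): because the model is
# cached at module level, repeated calls with the same context length return
# the same in-memory instance.
#
#   model_a = get_local_intent_model()
#   model_b = get_local_intent_model()
#   assert model_a is model_b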


def warm_up_intent_model() -> None:
    """Run one forward pass so model weights are loaded before real traffic."""
    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)


@simple_log_function_time()
def classify_intent(query: str) -> list[float]:
    tokenizer = get_intent_model_tokenizer()
    intent_model = get_local_intent_model()
    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)

    # Logits for each intent class, converted to probabilities via softmax.
    predictions = intent_model(model_input)[0]
    probabilities = tf.nn.softmax(predictions, axis=-1)

    # Express each class probability as a percentage rounded to 2 decimals.
    class_percentages = np.round(probabilities.numpy() * 100, 2)
    return list(class_percentages.tolist()[0])
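
# Usage sketch (illustrative; the class count and values depend on the
# fine-tuned INTENT_MODEL_VERSION checkpoint, so these numbers are made up):
#
#   classify_intent("What is the PTO policy?")
#   # -> e.g. [2.13, 94.50, 3.37]  # percentages, one entry per intent class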


@router.post("/intent-model")
async def process_intent_request(
    intent_request: IntentRequest,
) -> IntentResponse:
    # The indexing-only model server deliberately never loads this model.
    if INDEXING_ONLY:
        raise RuntimeError("Indexing model server should not call intent endpoint")

    class_percentages = classify_intent(intent_request.query)
    return IntentResponse(class_probs=class_percentages)
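
# Client-side sketch (assumption: the host/port below are deployment-specific
# and not defined in this module):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:9000/custom/intent-model",
#       json={"query": "How do I reset my password?"},
#   )
#   resp.json()  # -> {"class_probs": [...]}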