GPU Model Server (#2135)

Yuhong Sun authored 2024-08-14 11:04:28 -07:00, committed by GitHub
parent 0530f4283e
commit 1c10f54294
2 changed files with 25 additions and 11 deletions


@@ -66,22 +66,30 @@ def warm_up_intent_model() -> None:
         MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
     )
     intent_model = get_local_intent_model()
-    intent_model(query_ids=tokens["input_ids"], query_mask=tokens["attention_mask"])
+    device = intent_model.device
+    intent_model(
+        query_ids=tokens["input_ids"].to(device),
+        query_mask=tokens["attention_mask"].to(device),
+    )
 
 
 @simple_log_function_time()
 def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
     intent_model = get_local_intent_model()
+    device = intent_model.device
     outputs = intent_model(
-        query_ids=tokens["input_ids"], query_mask=tokens["attention_mask"]
+        query_ids=tokens["input_ids"].to(device),
+        query_mask=tokens["attention_mask"].to(device),
     )
 
     token_logits = outputs["token_logits"]
     intent_logits = outputs["intent_logits"]
 
-    intent_probabilities = F.softmax(intent_logits, dim=-1).numpy()[0]
-    token_probabilities = F.softmax(token_logits, dim=-1).numpy()[0]
+    # Move tensors to CPU before applying softmax and converting to numpy
+    intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
+    token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
 
     # Extract the probabilities for the positive class (index 1) for each token
     token_positive_probs = token_probabilities[:, 1].tolist()
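
Both edits in this file apply the same rule: input tensors have to be moved onto whatever device holds the model's weights, and outputs have to come back to the CPU before .numpy() is called, since NumPy conversion raises a TypeError on CUDA/MPS tensors. A minimal sketch of that round trip, using a stand-in nn.Linear rather than the repo's intent model (a plain module has no .device attribute, so here the device is read off its parameters):

import torch
import torch.nn.functional as F

# Hypothetical stand-in for the intent model; any nn.Module behaves the same way
model = torch.nn.Linear(8, 2)
device = next(model.parameters()).device  # where the weights actually live

tokens = torch.randn(1, 8)  # tokenized inputs start out on the CPU

with torch.no_grad():
    logits = model(tokens.to(device))  # inputs must sit on the weights' device

# .numpy() requires a CPU tensor, so move back before converting
probs = F.softmax(logits.cpu(), dim=-1).numpy()[0]

Without the .to(device) call the forward pass itself fails with a device-mismatch RuntimeError; without the .cpu() call the .numpy() conversion fails whenever the model is on a GPU.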


@@ -23,6 +23,8 @@ class HybridClassifier(nn.Module):
         self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
         self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
 
+        self.device = torch.device("cpu")
+
     def forward(
         self,
         query_ids: torch.Tensor,
@@ -51,17 +53,21 @@ class HybridClassifier(nn.Module):
             config = json.load(f)
         model = cls(**config)
 
-        if torch.cuda.is_available():
+        if torch.backends.mps.is_available():
+            # Apple silicon GPU
+            device = torch.device("mps")
+        elif torch.cuda.is_available():
             device = torch.device("cuda")
-            model.load_state_dict(torch.load(model_path, map_location=device))
-            model = model.to(device)
         else:
-            # No cuda, model most likely just loaded on CPU
-            model.load_state_dict(torch.load(model_path))
+            device = torch.device("cpu")
+
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        model = model.to(device)
+        model.device = device
 
         model.eval()
         # Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
         for param in model.parameters():
             param.requires_grad = False
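
The selection order above (MPS first, then CUDA, then CPU) together with map_location is the portable way to load a checkpoint on a machine other than the one it was trained on. A sketch of the same fallback as a standalone helper; pick_device and the checkpoint path are illustrative, not names from the repo:

import torch

def pick_device() -> torch.device:
    # Prefer Apple silicon's Metal backend, then CUDA, then fall back to CPU
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
# map_location remaps saved storages at load time, so a CUDA-trained
# checkpoint loads cleanly on a CPU-only or Apple silicon machine, e.g.:
#   state_dict = torch.load("model.pt", map_location=device)

Stashing the chosen device on the model (model.device = device) is what lets the model server code in the first file call tokens["input_ids"].to(intent_model.device) without re-deriving the device on every request.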