GPU Model Server (#2135)
parent 0530f4283e
commit 1c10f54294
@@ -66,22 +66,30 @@ def warm_up_intent_model() -> None:
         MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
     )
     intent_model = get_local_intent_model()
-    intent_model(query_ids=tokens["input_ids"], query_mask=tokens["attention_mask"])
+    device = intent_model.device
+    intent_model(
+        query_ids=tokens["input_ids"].to(device),
+        query_mask=tokens["attention_mask"].to(device),
+    )
 
 
 @simple_log_function_time()
 def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
     intent_model = get_local_intent_model()
+    device = intent_model.device
+
     outputs = intent_model(
-        query_ids=tokens["input_ids"], query_mask=tokens["attention_mask"]
+        query_ids=tokens["input_ids"].to(device),
+        query_mask=tokens["attention_mask"].to(device),
     )
 
     token_logits = outputs["token_logits"]
     intent_logits = outputs["intent_logits"]
-    intent_probabilities = F.softmax(intent_logits, dim=-1).numpy()[0]
 
-    token_probabilities = F.softmax(token_logits, dim=-1).numpy()[0]
+    # Move tensors to CPU before applying softmax and converting to numpy
+    intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
+    token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]
 
     # Extract the probabilities for the positive class (index 1) for each token
     token_positive_probs = token_probabilities[:, 1].tolist()
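In both functions the change is the same device round trip: tokenizer tensors are created on the CPU, must be moved to the model's device before the forward pass, and the resulting logits must return to the CPU before numpy can consume them. A minimal self-contained sketch of that pattern (the Linear stand-in and its shapes are illustrative, not Danswer's model):

    import torch
    import torch.nn.functional as F

    # Illustrative stand-in for the intent model: any module living on some device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(8, 2).to(device)

    inputs = torch.randn(1, 8)  # tokenizer outputs start life on the CPU
    with torch.no_grad():
        logits = model(inputs.to(device))  # inputs must join the model's device
    probs = F.softmax(logits.cpu(), dim=-1).numpy()  # and come back before .numpy()

Forgetting either hop raises a runtime error ("expected all tensors to be on the same device" on the way in, "can't convert cuda tensor to numpy" on the way out), which is exactly what the two `.to(device)` calls and the two `.cpu()` calls above prevent.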
@@ -23,6 +23,8 @@ class HybridClassifier(nn.Module):
         self.intent_classifier = nn.Linear(self.distilbert.config.dim, 2)
         self.dropout = nn.Dropout(self.distilbert.config.seq_classif_dropout)
 
+        self.device = torch.device("cpu")
+
     def forward(
         self,
         query_ids: torch.Tensor,
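The new attribute defaults to CPU and is overwritten at load time (see the next hunk). `nn.Module` itself exposes no built-in `.device`, so code either tracks one explicitly like this or infers it from a parameter. A sketch of the inference idiom, assuming all parameters share one device:

    import torch

    def module_device(module: torch.nn.Module) -> torch.device:
        # Infer device from the first parameter; valid whenever the whole
        # module was moved with .to(device), as it is here.
        return next(module.parameters()).device

Tracking an explicit attribute, as the diff does, avoids repeating that lookup on every request.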
@@ -51,17 +53,21 @@ class HybridClassifier(nn.Module):
             config = json.load(f)
         model = cls(**config)
 
-        if torch.cuda.is_available():
+        if torch.backends.mps.is_available():
+            # Apple silicon GPU
+            device = torch.device("mps")
+        elif torch.cuda.is_available():
             device = torch.device("cuda")
-            model.load_state_dict(torch.load(model_path, map_location=device))
-            model = model.to(device)
         else:
-            model.load_state_dict(torch.load(model_path))
+            # No cuda, model most likely just loaded on CPU
+            device = torch.device("cpu")
+
+        model.load_state_dict(torch.load(model_path, map_location=device))
+        model = model.to(device)
+        model.device = device
 
         model.eval()
+        # Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
         for param in model.parameters():
             param.requires_grad = False
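This hunk establishes a preference order of MPS, then CUDA, then CPU, and routes every checkpoint load through `map_location` so weights saved on one device load cleanly on another. A hedged standalone sketch of the same load-for-inference flow; `pick_device` and `load_for_inference` are illustrative names, not Danswer's API:

    import torch

    def pick_device() -> torch.device:
        # Prefer Apple-silicon GPU, then NVIDIA GPU, then fall back to CPU.
        if torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    def load_for_inference(model: torch.nn.Module, model_path: str) -> torch.nn.Module:
        device = pick_device()
        # map_location remaps tensors saved on another device onto this one.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        # eval() only toggles layers like dropout; gradients must be
        # disabled separately to save memory during inference.
        for param in model.parameters():
            param.requires_grad = False
        return model

Folding the `load_state_dict` / `.to(device)` calls out of the branches, as the diff does, keeps a single load path regardless of which device won.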