mirror of
https://github.com/ollama/ollama.git
synced 2025-11-10 21:37:14 +01:00
tests: reduce stress on CPU to 2 models (#12161)
* tests: reduce stress on CPU to 2 models — This should avoid flakes due to systems getting overloaded with 3 (or more) models running concurrently.
* tests: allow slow systems to pass on timeout — If a slow system is still streaming a response, and the response will pass validation, don't fail just because the system is slow.
* test: unload embedding models more quickly
This commit is contained in:
@@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) {
|
|||||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||||
targetLoadCount := 0
|
targetLoadCount := 0
|
||||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||||
|
chooseModels:
|
||||||
for i, model := range chosenModels {
|
for i, model := range chosenModels {
|
||||||
req := &api.GenerateRequest{Model: model}
|
req := &api.GenerateRequest{Model: model}
|
||||||
slog.Info("loading", "model", model)
|
slog.Info("loading", "model", model)
|
||||||
@@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
|
|||||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
|
||||||
|
for _, m := range models.Models {
|
||||||
|
if m.SizeVRAM == 0 {
|
||||||
|
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
|
||||||
|
break chooseModels
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if targetLoadCount == len(chosenModels) {
|
if targetLoadCount == len(chosenModels) {
|
||||||
|
|||||||
@@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := api.EmbeddingRequest{
|
req := api.EmbeddingRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: "why is the sky blue?",
|
||||||
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||||
|
|||||||
@@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||||||
done <- 0
|
done <- 0
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var response string
|
||||||
|
verify := func() {
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
response = buf.String()
|
||||||
|
atLeastOne := false
|
||||||
|
for _, resp := range anyResp {
|
||||||
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
|
atLeastOne = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !atLeastOne {
|
||||||
|
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
if buf.Len() == 0 {
|
if buf.Len() == 0 {
|
||||||
@@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||||||
if genErr != nil {
|
if genErr != nil {
|
||||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||||
}
|
}
|
||||||
// Verify the response contains the expected data
|
verify()
|
||||||
response := buf.String()
|
|
||||||
atLeastOne := false
|
|
||||||
for _, resp := range anyResp {
|
|
||||||
if strings.Contains(strings.ToLower(response), resp) {
|
|
||||||
atLeastOne = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !atLeastOne {
|
|
||||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
|
||||||
}
|
|
||||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("outer test context done while waiting for generate")
|
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||||
|
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||||
|
// if they are still generating valid responses
|
||||||
|
slog.Warn("outer test context done while waiting for generate")
|
||||||
|
verify()
|
||||||
}
|
}
|
||||||
return context
|
return context
|
||||||
}
|
}
|
||||||
@@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
|
|||||||
done <- 0
|
done <- 0
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var response string
|
||||||
|
verify := func() {
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
response = buf.String()
|
||||||
|
atLeastOne := false
|
||||||
|
for _, resp := range anyResp {
|
||||||
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
|
atLeastOne = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !atLeastOne {
|
||||||
|
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
if buf.Len() == 0 {
|
if buf.Len() == 0 {
|
||||||
@@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
|
|||||||
if genErr != nil {
|
if genErr != nil {
|
||||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||||
}
|
}
|
||||||
|
verify()
|
||||||
// Verify the response contains the expected data
|
|
||||||
response := buf.String()
|
|
||||||
atLeastOne := false
|
|
||||||
for _, resp := range anyResp {
|
|
||||||
if strings.Contains(strings.ToLower(response), resp) {
|
|
||||||
atLeastOne = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !atLeastOne {
|
|
||||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("outer test context done while waiting for generate")
|
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||||
|
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||||
|
// if they are still generating valid responses
|
||||||
|
slog.Warn("outer test context done while waiting for chat")
|
||||||
|
verify()
|
||||||
}
|
}
|
||||||
return &api.Message{Role: role, Content: buf.String()}
|
return &api.Message{Role: role, Content: buf.String()}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user