add batch embeddings

jmorganca 2024-04-14 20:53:20 -04:00
parent 8e30eb26bd
commit ad7e641815
8 changed files with 243 additions and 72 deletions

@@ -159,15 +159,17 @@ type Runner struct {
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
PromptBatch []string `json:"prompt_batch,omitempty"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
Embedding []float64 `json:"embedding,omitempty"`
EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
}
type CreateRequest struct {
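
As a quick aside on the wire format: the struct tags above imply that a batch request sets `prompt_batch` and omits `prompt`, while a batch response carries `embedding_batch` instead of `embedding`. A minimal, self-contained Go sketch of the request side (illustrative only, not code from this commit; `KeepAlive` and `Options` are left out for brevity):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors the updated api.EmbeddingRequest above, trimmed to the fields
// relevant to batching.
type embeddingRequest struct {
	Model       string   `json:"model"`
	Prompt      string   `json:"prompt,omitempty"`
	PromptBatch []string `json:"prompt_batch,omitempty"`
}

func main() {
	// Leaving Prompt empty means omitempty drops it, so only the batch
	// field appears on the wire.
	req := embeddingRequest{
		Model:       "all-minilm",
		PromptBatch: []string{"first document", "second document"},
	}
	b, _ := json.Marshal(req)
	fmt.Println(string(b))
	// Output: {"model":"all-minilm","prompt_batch":["first document","second document"]}
}
```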

@@ -1010,7 +1010,8 @@ Generate embeddings from a model
### Parameters
- `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for
- `prompt`: string to generate the embedding for
- `prompt_batch`: array of strings to generate a batch of embeddings for
Advanced parameters:
@@ -1038,3 +1039,33 @@ curl http://localhost:11434/api/embeddings -d '{
]
}
```
#### Request (batch)
```shell
curl http://localhost:11434/api/embeddings -d '{
"model": "all-minilm",
"prompt_batch": [
"Here is an article about llamas...",
"Here is another article about llamas..."
]
}'
```
#### Response
```json
{
"embedding_batch": [
[
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
],
[
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]
]
}
```
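
For clients not using curl, a minimal Go sketch equivalent to the batch request above; the endpoint and field names come from the examples, everything else (error handling, payload contents) is illustrative:

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Same payload as the batch curl example.
	payload, _ := json.Marshal(map[string]any{
		"model": "all-minilm",
		"prompt_batch": []string{
			"Here is an article about llamas...",
			"Here is another article about llamas...",
		},
	})

	resp, err := http.Post("http://localhost:11434/api/embeddings", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Only the batch field is decoded here; single-prompt requests would
	// read "embedding" instead.
	var out struct {
		EmbeddingBatch [][]float64 `json:"embedding_batch"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println("embeddings returned:", len(out.EmbeddingBatch))
}
```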

@@ -0,0 +1,64 @@
//go:build integration
package integration
import (
"context"
"net/http"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAllMiniLMEmbedding(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
if len(res.Embedding) != 384 {
t.Fatalf("Expected 384 floats to be returned, got %v", len(res.Embedding))
}
if res.Embedding[0] != 0.146763876080513 {
t.Fatalf("Expected first embedding float to be 0.146763876080513, got %v", res.Embedding[0])
}
}
func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbeddingRequest{
Model: "all-minilm",
PromptBatch: []string{"why is the sky blue?", "why is the sky blue?"},
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
if len(res.EmbeddingBatch) != 2 {
t.Fatal("Expected 2 embeddings to be returned")
}
if len(res.EmbeddingBatch[0]) != 384 {
t.Fatalf("Expected first embedding to have 384 floats, got %v", len(res.EmbeddingBatch[0]))
}
if res.EmbeddingBatch[0][0] != 0.146763876080513 && res.EmbeddingBatch[1][0] != 0.146763876080513 {
t.Fatalf("Expected first embedding floats to be 0.146763876080513, got %v, %v", res.EmbeddingBatch[0][0], res.EmbeddingBatch[1][0])
}
}

@@ -5,6 +5,7 @@ package integration
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -24,6 +25,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -285,6 +287,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
stream := false
return []api.GenerateRequest{
{
Model: "orca-mini",
@@ -336,3 +339,83 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}
func EmbeddingTestHelper(ctx context.Context, t *testing.T, client *http.Client, req api.EmbeddingRequest) api.EmbeddingResponse {
// TODO maybe stuff in an init routine?
lifecycle.InitLogging()
requestJSON, err := json.Marshal(req)
if err != nil {
t.Fatalf("Error serializing request: %v", err)
}
defer func() {
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
fp, err := os.Open(lifecycle.ServerLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err = os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}()
scheme, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
assert.NoError(t, StartServer(ctx, testEndpoint))
}
err = PullIfMissing(ctx, client, scheme, testEndpoint, req.Model)
if err != nil {
t.Fatalf("Error pulling model: %v", err)
}
// Make the request and get the response
httpReq, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/embeddings", bytes.NewReader(requestJSON))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}
// Set the content type for the request
httpReq.Header.Set("Content-Type", "application/json")
// Make the request with the HTTP client
response, err := client.Do(httpReq.WithContext(ctx))
if err != nil {
t.Fatalf("Error making request: %v", err)
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))
// Verify the response is valid JSON
var res api.EmbeddingResponse
err = json.Unmarshal(body, &res)
if err != nil {
assert.NoError(t, err, body)
}
return res
}

@@ -3209,54 +3209,27 @@ int main(int argc, char **argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
json prompt;
if (body.count("content") != 0)
{
prompt = body["content"];
}
else
{
prompt = "";
const int id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id);
llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
task_result recv = llama.queue_results.recv(id);
llama.queue_results.remove_waiting_task_id(id);
json embeddings = json::array();
for (auto & elem : recv.result_json["results"]) {
embeddings.push_back(json_value(elem, "embedding", json::array()));
}
json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
}
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
json result = json{{"embeddings", embeddings}};
return res.set_content(result.dump(), "application/json; charset=utf-8");
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()
/*{
bool running = true;
while (running)
{
running = llama.update_slots();
}
}*/
//);
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);

@@ -32,7 +32,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Embeddings(ctx context.Context, prompt []string) ([][]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -736,15 +736,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("max retries exceeded")
}
type EmbeddingRequest struct {
Content string `json:"content"`
type EmbeddingsRequest struct {
Contents []string `json:"contents"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
type EmbeddingsResponse struct {
Embeddings [][]float64 `json:"embeddings"`
}
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -758,12 +758,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbeddingsRequest{Contents: prompts})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embeddings", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
@@ -780,17 +780,19 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("error reading embed response: %w", err)
}
fmt.Println("embeddings response", string(body))
if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}
var embedding EmbeddingResponse
var embedding EmbeddingsResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return embedding.Embedding, nil
return embedding.Embeddings, nil
}
type TokenizeRequest struct {

@@ -403,23 +403,39 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
// an empty request loads the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
switch {
// single embedding
case len(req.Prompt) > 0:
embeddings, err := runner.llama.Embeddings(c.Request.Context(), []string{req.Prompt})
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{Embedding: embeddings[0]}
c.JSON(http.StatusOK, resp)
resp := api.EmbeddingResponse{
Embedding: embedding,
// batch embeddings
case len(req.PromptBatch) > 0:
embeddings, err := runner.llama.Embeddings(c.Request.Context(), req.PromptBatch)
if err != nil {
slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
c.JSON(http.StatusOK, resp)
// empty prompt loads the model
default:
if req.PromptBatch != nil {
c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
} else {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
}
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) PullModelHandler(c *gin.Context) {

@@ -530,7 +530,7 @@ type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingResp [][]float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
@@ -546,7 +546,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *mockLlm) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
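
With `embeddingResp` widened to `[][]float64`, scheduler tests can stub a whole batch at once. A hypothetical snippet showing the shape (it assumes the test file's usual `context` and `testing` imports; nothing here is taken from the commit itself):

```go
// Hypothetical test using the updated mock above.
func TestMockLlmEmbeddings(t *testing.T) {
	mock := &mockLlm{
		// One vector per input prompt, matching the new [][]float64 field.
		embeddingResp: [][]float64{
			{0.1, 0.2, 0.3},
			{0.4, 0.5, 0.6},
		},
	}

	got, err := mock.Embeddings(context.Background(), []string{"a", "b"})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(got) != 2 {
		t.Fatalf("expected 2 embeddings, got %d", len(got))
	}
}
```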