Mirror of https://github.com/ollama/ollama.git (synced 2025-03-29 11:11:47 +01:00)

add batch embeddings

parent 8e30eb26bd
commit ad7e641815
Changed paths: api, docs, integration, llm, server

api/types.go (10 changed lines)
@@ -159,15 +159,17 @@ type Runner struct {
 }

 type EmbeddingRequest struct {
-	Model string `json:"model"`
-	Prompt string `json:"prompt"`
-	KeepAlive *Duration `json:"keep_alive,omitempty"`
+	Model string `json:"model"`
+	Prompt string `json:"prompt,omitempty"`
+	PromptBatch []string `json:"prompt_batch,omitempty"`
+	KeepAlive *Duration `json:"keep_alive,omitempty"`

 	Options map[string]interface{} `json:"options"`
 }

 type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+	Embedding []float64 `json:"embedding,omitempty"`
+	EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
 }

 type CreateRequest struct {
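For reference, a minimal client-side sketch of how the new fields might be used. It assumes the existing `api.Client` and its `Embeddings` helper (not part of this diff) forward the request unchanged; field names come from the struct tags above.

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// assumes OLLAMA_HOST (or the default localhost:11434) points at a running server
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// PromptBatch is set, so the server is expected to fill EmbeddingBatch;
	// a plain Prompt request would fill Embedding instead.
	resp, err := client.Embeddings(context.Background(), &api.EmbeddingRequest{
		Model: "all-minilm",
		PromptBatch: []string{
			"Here is an article about llamas...",
			"Here is another article about llamas...",
		},
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("embeddings returned:", len(resp.EmbeddingBatch))
}
```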
docs/api.md (33 changed lines)
@@ -1010,7 +1010,8 @@ Generate embeddings from a model
 ### Parameters

 - `model`: name of model to generate embeddings from
-- `prompt`: text to generate embeddings for
+- `prompt`: string to generate the embedding for
+- `prompt_batch`: array of strings to generate a batch of embeddings for

 Advanced parameters:
@@ -1038,3 +1039,33 @@ curl http://localhost:11434/api/embeddings -d '{
  ]
}
```

#### Request (batch)

```shell
curl http://localhost:11434/api/embeddings -d '{
  "model": "all-minilm",
  "prompt_batch": [
    "Here is an article about llamas...",
    "Here is another article about llamas..."
  ]
}'
```

#### Response

```json
{
  "embedding_batch": [
    [
      0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
      0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
    ],
    [
      0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
      0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
    ]
  ]
}
```
integration/embedding_test.go (new file, 64 lines)
@@ -0,0 +1,64 @@
//go:build integration

package integration

import (
	"context"
	"net/http"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestAllMiniLMEmbedding(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbeddingRequest{
		Model:  "all-minilm",
		Prompt: "why is the sky blue?",
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}

	res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)

	if len(res.Embedding) != 384 {
		t.Fatalf("Expected 384 floats to be returned, got %v", len(res.Embedding))
	}

	if res.Embedding[0] != 0.146763876080513 {
		t.Fatalf("Expected first embedding float to be 0.146763876080513, got %v", res.Embedding[0])
	}
}

func TestAllMiniLMEmbeddings(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbeddingRequest{
		Model:       "all-minilm",
		PromptBatch: []string{"why is the sky blue?", "why is the sky blue?"},
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}

	res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)

	if len(res.EmbeddingBatch) != 2 {
		t.Fatal("Expected 2 embeddings to be returned")
	}

	if len(res.EmbeddingBatch[0]) != 384 {
		t.Fatalf("Expected first embedding to have 384 floats, got %v", len(res.EmbeddingBatch[0]))
	}

	if res.EmbeddingBatch[0][0] != 0.146763876080513 || res.EmbeddingBatch[1][0] != 0.146763876080513 {
		t.Fatalf("Expected first embedding floats to be 0.146763876080513, got %v, %v", res.EmbeddingBatch[0][0], res.EmbeddingBatch[1][0])
	}
}
@@ -5,6 +5,7 @@ package integration
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
@@ -24,6 +25,7 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/app/lifecycle"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

@@ -285,6 +287,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
	stream := false
	return []api.GenerateRequest{
		{
			Model: "orca-mini",
@@ -336,3 +339,83 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
			[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
	}
}

func EmbeddingTestHelper(ctx context.Context, t *testing.T, client *http.Client, req api.EmbeddingRequest) api.EmbeddingResponse {
	// TODO maybe stuff in an init routine?
	lifecycle.InitLogging()

	requestJSON, err := json.Marshal(req)
	if err != nil {
		t.Fatalf("Error serializing request: %v", err)
	}
	defer func() {
		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
			defer serverProcMutex.Unlock()
			if t.Failed() {
				fp, err := os.Open(lifecycle.ServerLogFile)
				if err != nil {
					slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
					return
				}
				data, err := io.ReadAll(fp)
				if err != nil {
					slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
					return
				}
				slog.Warn("SERVER LOG FOLLOWS")
				os.Stderr.Write(data)
				slog.Warn("END OF SERVER")
			}
			err = os.Remove(lifecycle.ServerLogFile)
			if err != nil && !os.IsNotExist(err) {
				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
			}
		}
	}()
	scheme, testEndpoint := GetTestEndpoint()

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
		serverProcMutex.Lock()
		fp, err := os.CreateTemp("", "ollama-server-*.log")
		if err != nil {
			t.Fatalf("failed to generate log file: %s", err)
		}
		lifecycle.ServerLogFile = fp.Name()
		fp.Close()
		assert.NoError(t, StartServer(ctx, testEndpoint))
	}

	err = PullIfMissing(ctx, client, scheme, testEndpoint, req.Model)
	if err != nil {
		t.Fatalf("Error pulling model: %v", err)
	}

	// Make the request and get the response
	httpReq, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/embeddings", bytes.NewReader(requestJSON))
	if err != nil {
		t.Fatalf("Error creating request: %v", err)
	}

	// Set the content type for the request
	httpReq.Header.Set("Content-Type", "application/json")

	// Make the request with the HTTP client
	response, err := client.Do(httpReq.WithContext(ctx))
	if err != nil {
		t.Fatalf("Error making request: %v", err)
	}
	defer response.Body.Close()
	body, err := io.ReadAll(response.Body)
	assert.NoError(t, err)
	assert.Equal(t, response.StatusCode, 200, string(body))

	// Verify the response is valid JSON
	var res api.EmbeddingResponse
	err = json.Unmarshal(body, &res)
	if err != nil {
		assert.NoError(t, err, body)
	}

	return res
}
llm/ext_server/server.cpp (vendored, 55 changed lines)

@@ -3209,54 +3209,27 @@ int main(int argc, char **argv) {
             return res.set_content(data.dump(), "application/json; charset=utf-8");
         });

-    svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
+    svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
             {
                 res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
                 const json body = json::parse(req.body);
-                json prompt;
-                if (body.count("content") != 0)
-                {
-                    prompt = body["content"];
-                }
-                else
-                {
-                    prompt = "";
-                }
-
-                json image_data;
-                if (body.count("image_data") != 0) {
-                    image_data = body["image_data"];
-                }
-                else
-                {
-                    image_data = "";
-                }
-
-                // create and queue the task
-                const int task_id = llama.queue_tasks.get_new_id();
-                llama.queue_results.add_waiting_task_id(task_id);
-                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
-
-                // get the result
-                task_result result = llama.queue_results.recv(task_id);
-                llama.queue_results.remove_waiting_task_id(task_id);
-
-                // send the result
-                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+
+                const int id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(id);
+                llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
+
+                task_result recv = llama.queue_results.recv(id);
+                llama.queue_results.remove_waiting_task_id(id);
+
+                json embeddings = json::array();
+                for (auto & elem : recv.result_json["results"]) {
+                    embeddings.push_back(json_value(elem, "embedding", json::array()));
+                }
+
+                json result = json{{"embeddings", embeddings}};
+                return res.set_content(result.dump(), "application/json; charset=utf-8");
             });

     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     // "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()
     /*{
         bool running = true;
         while (running)
         {
             running = llama.update_slots();
         }
     }*/
     //);

     if (sparams.n_threads_http < 1) {
         // +2 threads for monitoring endpoints
         sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
@@ -32,7 +32,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embeddings(ctx context.Context, prompt []string) ([][]float64, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -736,15 +736,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return fmt.Errorf("max retries exceeded")
 }

-type EmbeddingRequest struct {
-	Content string `json:"content"`
+type EmbeddingsRequest struct {
+	Contents []string `json:"contents"`
 }

-type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+type EmbeddingsResponse struct {
+	Embeddings [][]float64 `json:"embeddings"`
 }

-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -758,12 +758,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
 	}

-	data, err := json.Marshal(TokenizeRequest{Content: prompt})
+	data, err := json.Marshal(EmbeddingsRequest{Contents: prompts})
 	if err != nil {
 		return nil, fmt.Errorf("error marshaling embed data: %w", err)
 	}

-	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embeddings", s.port), bytes.NewBuffer(data))
 	if err != nil {
 		return nil, fmt.Errorf("error creating embed request: %w", err)
 	}
@@ -780,17 +780,19 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
 		return nil, fmt.Errorf("error reading embed response: %w", err)
 	}

+	fmt.Println("embeddings response", string(body))
+
 	if resp.StatusCode >= 400 {
 		log.Printf("llm encode error: %s", body)
 		return nil, fmt.Errorf("%s", body)
 	}

-	var embedding EmbeddingResponse
+	var embedding EmbeddingsResponse
 	if err := json.Unmarshal(body, &embedding); err != nil {
 		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
 	}

-	return embedding.Embedding, nil
+	return embedding.Embeddings, nil
 }

 type TokenizeRequest struct {
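As a sketch of the wire format these types imply between the Go wrapper and the ext server's `/embeddings` endpoint; the types are local copies of the unexported ones above, and the port and values are placeholders.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// local copies of the unexported runner types shown in the diff above,
// repeated here only to illustrate the JSON exchanged with the ext server
type embeddingsRequest struct {
	Contents []string `json:"contents"`
}

type embeddingsResponse struct {
	Embeddings [][]float64 `json:"embeddings"`
}

func main() {
	// body POSTed to http://127.0.0.1:<port>/embeddings
	body, err := json.Marshal(embeddingsRequest{
		Contents: []string{"why is the sky blue?", "why is grass green?"},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body)) // {"contents":["why is the sky blue?","why is grass green?"]}

	// shape of the body the handler sends back (placeholder values)
	raw := []byte(`{"embeddings":[[0.1,0.2],[0.3,0.4]]}`)
	var resp embeddingsResponse
	if err := json.Unmarshal(raw, &resp); err != nil {
		panic(err)
	}
	fmt.Println(len(resp.Embeddings), "embeddings") // 2 embeddings
}
```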
@@ -403,23 +403,39 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
 		return
 	}

-	// an empty request loads the model
-	if req.Prompt == "" {
-		c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
-		return
-	}
-
-	embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
-	if err != nil {
-		slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
-		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
-		return
-	}
-
-	resp := api.EmbeddingResponse{
-		Embedding: embedding,
-	}
-	c.JSON(http.StatusOK, resp)
+	switch {
+	// single embedding
+	case len(req.Prompt) > 0:
+		embeddings, err := runner.llama.Embeddings(c.Request.Context(), []string{req.Prompt})
+		if err != nil {
+			slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+			return
+		}
+
+		resp := api.EmbeddingResponse{Embedding: embeddings[0]}
+		c.JSON(http.StatusOK, resp)
+
+	// batch embeddings
+	case len(req.PromptBatch) > 0:
+		embeddings, err := runner.llama.Embeddings(c.Request.Context(), req.PromptBatch)
+		if err != nil {
+			slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
+			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
+			return
+		}
+
+		resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
+		c.JSON(http.StatusOK, resp)
+
+	// empty prompt loads the model
+	default:
+		if req.PromptBatch != nil {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
+		} else {
+			c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
+		}
+	}
 }

 func (s *Server) PullModelHandler(c *gin.Context) {
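One way a caller might consume the two response shapes the handler now produces; `embeddingsFrom` is a hypothetical convenience helper, not part of this change.

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

// embeddingsFrom is a hypothetical helper (not in this commit): single-prompt
// requests fill Embedding, prompt_batch requests fill EmbeddingBatch, and an
// empty request fills neither (it only loads the model).
func embeddingsFrom(resp api.EmbeddingResponse) [][]float64 {
	if len(resp.EmbeddingBatch) > 0 {
		return resp.EmbeddingBatch
	}
	if len(resp.Embedding) > 0 {
		return [][]float64{resp.Embedding}
	}
	return nil
}

func main() {
	single := api.EmbeddingResponse{Embedding: []float64{0.1, 0.2}}
	batch := api.EmbeddingResponse{EmbeddingBatch: [][]float64{{0.1}, {0.2}}}
	fmt.Println(len(embeddingsFrom(single)), len(embeddingsFrom(batch))) // 1 2
}
```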
@@ -530,7 +530,7 @@ type mockLlm struct {
 	pingResp         error
 	waitResp         error
 	completionResp   error
-	embeddingResp    []float64
+	embeddingResp    [][]float64
 	embeddingRespErr error
 	tokenizeResp     []int
 	tokenizeRespErr  error
@@ -546,7 +546,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
 func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
 	return s.completionResp
 }
-func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *mockLlm) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {