mirror of
https://github.com/ollama/ollama.git
synced 2025-04-09 20:29:23 +02:00
76 lines
1.4 KiB
Go
76 lines
1.4 KiB
Go
package bert_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/model"
|
|
)
|
|
|
|
// blob returns the on-disk path of the model layer for the given tag of the
// "all-minilm" model in the local Ollama model store.
//
// The store root is $OLLAMA_MODELS when that variable is set, otherwise
// ~/.ollama/models. blob reads the manifest at
// <root>/manifests/registry.ollama.ai/library/all-minilm/<tag>, finds the first
// layer whose mediaType is "application/vnd.ollama.image.model", and maps its
// digest to <root>/blobs/<digest with ':' replaced by '-'>.
//
// If the manifest file does not exist (the model has not been pulled), the
// calling test is skipped instead of failed, because it depends on local state
// outside the repository. Any other error, or a manifest without a model layer,
// fails the test.
func blob(t *testing.T, tag string) string {
	t.Helper()

	p := os.Getenv("OLLAMA_MODELS")
	if p == "" {
		home, err := os.UserHomeDir()
		if err != nil {
			t.Fatal(err)
		}
		p = filepath.Join(home, ".ollama", "models")
	}

	manifestPath := filepath.Join(p, "manifests", "registry.ollama.ai", "library", "all-minilm", tag)
	manifestBytes, err := os.ReadFile(manifestPath)
	switch {
	case os.IsNotExist(err):
		// Missing model is an environment issue, not a code failure.
		t.Skipf("model manifest %s not found; pull all-minilm:%s to run this test", manifestPath, tag)
	case err != nil:
		t.Fatal(err)
	}

	var manifest struct {
		Layers []struct {
			MediaType string `json:"mediaType"`
			Digest    string `json:"digest"`
		} `json:"layers"`
	}

	if err := json.Unmarshal(manifestBytes, &manifest); err != nil {
		t.Fatal(err)
	}

	var digest string
	for _, layer := range manifest.Layers {
		if layer.MediaType == "application/vnd.ollama.image.model" {
			digest = layer.Digest
			break
		}
	}

	if digest == "" {
		t.Fatalf("no model layer found in manifest %s", manifestPath)
	}

	// Blob files are named by digest with ':' replaced by '-' (e.g. sha256-<hex>).
	return filepath.Join(p, "blobs", strings.ReplaceAll(digest, ":", "-"))
}
|
|
|
|
func TestEmbedding(t *testing.T) {
|
|
m, err := model.New(blob(t, "latest"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
text, err := os.ReadFile(filepath.Join("..", "testdata", "war-and-peace.txt"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
inputIDs, err := m.(model.TextProcessor).Encode(string(text))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
logit, err := model.Forward(m, model.WithInputIDs(inputIDs))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
t.Log(ml.Dump(logit))
|
|
}
|