ollama/model/bert/model_test.go
Michael Yang cf1dbcfc5a next bert
2025-02-11 16:06:55 -08:00

76 lines
1.4 KiB
Go

package bert_test
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
)
// blob returns the path of the model weights blob for the all-minilm model
// with the given tag.
//
// It reads the tag's manifest from the .ollama/models directory under the
// user's home directory, selects the first layer whose media type is
// "application/vnd.ollama.image.model", and maps that layer's digest to a
// file name under the blobs directory by replacing ":" with "-". Any failure
// along the way (no home directory, unreadable or malformed manifest, no
// usable model layer) aborts the calling test via t.Fatal.
func blob(t *testing.T, tag string) string {
	t.Helper()

	homeDir, err := os.UserHomeDir()
	if err != nil {
		t.Fatal(err)
	}

	modelsDir := filepath.Join(homeDir, ".ollama", "models")
	manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "all-minilm", tag)

	raw, err := os.ReadFile(manifestPath)
	if err != nil {
		t.Fatal(err)
	}

	type layer struct {
		MediaType string `json:"mediaType"`
		Digest    string `json:"digest"`
	}
	var parsed struct {
		Layers []layer
	}
	if err := json.Unmarshal(raw, &parsed); err != nil {
		t.Fatal(err)
	}

	// Only the first model layer is considered, even if its digest is empty.
	idx := -1
	for i, l := range parsed.Layers {
		if l.MediaType == "application/vnd.ollama.image.model" {
			idx = i
			break
		}
	}
	if idx < 0 || parsed.Layers[idx].Digest == "" {
		t.Fatal("no model layer found")
	}

	blobName := strings.ReplaceAll(parsed.Layers[idx].Digest, ":", "-")
	return filepath.Join(modelsDir, "blobs", blobName)
}
// TestEmbedding loads the locally installed all-minilm model (tag "latest"),
// encodes the contents of ../testdata/war-and-peace.txt into input IDs, runs
// a forward pass over them, and logs a dump of the result.
//
// It makes no assertions about the output values; it fails only when one of
// the steps returns an error or the model cannot process text.
func TestEmbedding(t *testing.T) {
	m, err := model.New(blob(t, "latest"))
	if err != nil {
		t.Fatal(err)
	}

	// Use the comma-ok form so a model that is not a text processor fails this
	// test cleanly instead of panicking and aborting the whole test binary.
	processor, ok := m.(model.TextProcessor)
	if !ok {
		t.Fatalf("model %T does not implement model.TextProcessor", m)
	}

	text, err := os.ReadFile(filepath.Join("..", "testdata", "war-and-peace.txt"))
	if err != nil {
		t.Fatal(err)
	}

	inputIDs, err := processor.Encode(string(text))
	if err != nil {
		t.Fatal(err)
	}

	logit, err := model.Forward(m, model.WithInputIDs(inputIDs))
	if err != nil {
		t.Fatal(err)
	}

	t.Log(ml.Dump(logit))
}