implement loading ggml lora adapters through the modelfile

Michael Yang
2023-08-03 17:16:05 -07:00
parent d791df75dd
commit 6de5d032e1
5 changed files with 65 additions and 13 deletions
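
The point of this change is to let a Modelfile reference one or more GGML LoRA adapters that are applied on top of the base model when it is loaded. As a rough sketch, a Modelfile using this might look like the following (the adapter path is illustrative, and the Modelfile parsing itself lives in the other changed files, which are not shown in this excerpt):

    FROM llama2
    ADAPTER ./my-lora-adapter.ggml.bin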


@@ -136,7 +136,7 @@ type llamaHyperparameters struct {
     FileType
 }
 
-func newLlama(model string, opts api.Options) (*llama, error) {
+func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {
     if _, err := os.Stat(model); err != nil {
         return nil, err
     }
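
With the new signature, every caller has to pass a slice of adapter paths; an empty slice keeps the old behaviour. A hypothetical call site (the adapterPaths variable and the path are illustrative, not taken from this diff):

    // Sketch only: in this commit the adapter paths would come from the
    // parsed Modelfile; an empty slice loads the base model unmodified.
    adapterPaths := []string{"/path/to/lora.ggml.bin"}
    llm, err := newLlama(modelPath, adapterPaths, opts)
    if err != nil {
        return err
    }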
@@ -161,6 +161,12 @@ func newLlama(model string, opts api.Options) (*llama, error) {
     params.embedding = C.bool(llm.EmbeddingOnly)
     params.rope_freq_base = C.float(llm.RopeFrequencyBase)
     params.rope_freq_scale = C.float(llm.RopeFrequencyScale)
 
+    if len(adapters) > 0 && llm.UseMMap {
+        log.Printf("must disable mmap to use lora adapters")
+        params.use_mmap = C.bool(false)
+    }
+
     llm.params = &params
 
     cModel := C.CString(model)
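
The reason mmap has to be turned off: applying a LoRA patches the base tensors in place, and llama.cpp cannot do that against memory-mapped (read-only) weights, so the model must be fully loaded into writable memory first. The rule is small enough to restate as a pure function (this helper is only illustrative, it is not part of the commit):

    // useMMap mirrors the check above: honor the requested mmap setting only
    // when there are no LoRA adapters to apply to the weights.
    func useMMap(requested bool, adapters []string) bool {
        if len(adapters) > 0 {
            return false
        }
        return requested
    }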
@@ -176,6 +182,15 @@ func newLlama(model string, opts api.Options) (*llama, error) {
         return nil, errors.New("failed to create context")
     }
 
+    for _, adapter := range adapters {
+        cAdapter := C.CString(adapter)
+        defer C.free(unsafe.Pointer(cAdapter))
+
+        if retval := C.llama_model_apply_lora_from_file(llm.model, cAdapter, nil, C.int(llm.NumThread)); retval != 0 {
+            return nil, fmt.Errorf("failed to load adapter %s", adapter)
+        }
+    }
+
     // warm up the model
     bos := []C.llama_token{C.llama_token_bos()}
     C.llama_eval(llm.ctx, unsafe.SliceData(bos), C.int(len(bos)), 0, C.int(opts.NumThread))
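
Two details of the llama_model_apply_lora_from_file call: the nil third argument is path_base_model, so the adapter is merged directly into the weights that were just loaded (which may be quantized), and the last argument is the thread count used for the merge. If a separate f16 base model were wanted as the merge target instead, the call would look roughly like this (the cBase variable and path are hypothetical, not part of this commit):

    // Hypothetical variant: merge the adapter against a separate f16 base
    // model rather than the loaded, possibly quantized, weights.
    cBase := C.CString("/path/to/base-model-f16.ggml.bin")
    defer C.free(unsafe.Pointer(cBase))

    if retval := C.llama_model_apply_lora_from_file(llm.model, cAdapter, cBase, C.int(llm.NumThread)); retval != 0 {
        return nil, fmt.Errorf("failed to load adapter %s", adapter)
    }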