This commit is contained in:
jmorganca 2025-03-22 10:15:52 -07:00
parent 4d8dac8ffc
commit caddb1e4cf
3 changed files with 6 additions and 28 deletions

View File

@@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
-newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic)
+newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil

View File

@@ -46,13 +46,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
return nil, model.ErrNoVisionModel
}
// Decode image
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
// Process image
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
@@ -100,38 +98,18 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return result, nil
}
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
-positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
-outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
-if err != nil {
-return nil, err
-}
-// Handle multimodal inputs
-// var except []int
-// hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs)
-// for _, image := range opts.Multimodal {
-// visionOutputs := image.Multimodal.(ml.Tensor)
-// // Copy vision outputs into the hidden state
-// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
-// for i := range visionOutputs.Dim(1) {
-// except = append(except, image.Index+i)
-// }
-// }
-return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -116,7 +116,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
// Process text inputs
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)