diff --git a/server/modelpath.go b/server/modelpath.go index 13b26eeb6..651ee1be4 100644 --- a/server/modelpath.go +++ b/server/modelpath.go @@ -67,6 +67,20 @@ func ParseModelPath(name string) ModelPath { return mp } +var errModelPathInvalid = errors.New("invalid model path") + +func (mp ModelPath) Validate() error { + if mp.Repository == "" { + return fmt.Errorf("%w: model repository name is required", errModelPathInvalid) + } + + if strings.Contains(mp.Tag, ":") { + return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid) + } + + return nil +} + func (mp ModelPath) GetNamespaceRepository() string { return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository) } diff --git a/server/routes.go b/server/routes.go index 03bf8acee..bc8ea8043 100644 --- a/server/routes.go +++ b/server/routes.go @@ -416,8 +416,8 @@ func CreateModelHandler(c *gin.Context) { return } - if strings.Count(req.Name, ":") > 1 { - c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "':' (colon) is not allowed in tag names"}) + if err := ParseModelPath(req.Name).Validate(); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } @@ -645,6 +645,11 @@ func CopyModelHandler(c *gin.Context) { return } + if err := ParseModelPath(req.Destination).Validate(); err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + if err := CopyModel(req.Source, req.Destination); err != nil { if os.IsNotExist(err) { c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})