From 6c39bc1f60f7b22c2e2bc1acaee60694b15b9dcd Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Tue, 27 Dec 2022 16:24:23 +0100 Subject: [PATCH] opts: use strings.Cut for handling key/value pairs Signed-off-by: Sebastiaan van Stijn --- cli/command/container/opts_test.go | 4 +- opts/config.go | 20 +++++----- opts/env.go | 11 +++--- opts/file.go | 11 +++--- opts/gpus.go | 16 ++++---- opts/hosts.go | 29 ++++++++------ opts/mount.go | 63 ++++++++++++++---------------- opts/network.go | 41 +++++++++---------- opts/opts.go | 47 +++++++++++----------- opts/opts_test.go | 8 ++-- opts/parse.go | 32 ++++++--------- opts/port.go | 25 ++++++------ opts/secret.go | 21 +++++----- opts/throttledevice.go | 22 ++++++----- opts/weightdevice.go | 14 +++---- 15 files changed, 175 insertions(+), 189 deletions(-) diff --git a/cli/command/container/opts_test.go b/cli/command/container/opts_test.go index 7e78562a48..6421f80a70 100644 --- a/cli/command/container/opts_test.go +++ b/cli/command/container/opts_test.go @@ -649,8 +649,8 @@ func TestRunFlagsParseShmSize(t *testing.T) { func TestParseRestartPolicy(t *testing.T) { invalids := map[string]string{ - "always:2:3": "invalid restart policy format", - "on-failure:invalid": "maximum retry count must be an integer", + "always:2:3": "invalid restart policy format: maximum retry count must be an integer", + "on-failure:invalid": "invalid restart policy format: maximum retry count must be an integer", } valids := map[string]container.RestartPolicy{ "": {}, diff --git a/opts/config.go b/opts/config.go index 40bc13e251..3be0fa93dd 100644 --- a/opts/config.go +++ b/opts/config.go @@ -40,25 +40,23 @@ func (o *ConfigOpt) Set(value string) error { } for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := strings.ToLower(parts[0]) - - if len(parts) != 2 { + key, val, ok := strings.Cut(field, "=") + if !ok || key == "" { return fmt.Errorf("invalid field '%s' must be a key=value pair", field) } - value := parts[1] - switch key { + // TODO(thaJeztah): these options should not be case-insensitive. + switch strings.ToLower(key) { case "source", "src": - options.ConfigName = value + options.ConfigName = val case "target": - options.File.Name = value + options.File.Name = val case "uid": - options.File.UID = value + options.File.UID = val case "gid": - options.File.GID = value + options.File.GID = val case "mode": - m, err := strconv.ParseUint(value, 0, 32) + m, err := strconv.ParseUint(val, 0, 32) if err != nil { return fmt.Errorf("invalid mode specified: %v", err) } diff --git a/opts/env.go b/opts/env.go index d21c8ccbef..214d6f4400 100644 --- a/opts/env.go +++ b/opts/env.go @@ -16,15 +16,16 @@ import ( // // The only validation here is to check if name is empty, per #25099 func ValidateEnv(val string) (string, error) { - arr := strings.SplitN(val, "=", 2) - if arr[0] == "" { + k, _, hasValue := strings.Cut(val, "=") + if k == "" { return "", errors.New("invalid environment variable: " + val) } - if len(arr) > 1 { + if hasValue { + // val contains a "=" (but value may be an empty string) return val, nil } - if envVal, ok := os.LookupEnv(arr[0]); ok { - return arr[0] + "=" + envVal, nil + if envVal, ok := os.LookupEnv(k); ok { + return k + "=" + envVal, nil } return val, nil } diff --git a/opts/file.go b/opts/file.go index 2346cc1670..72b90e117f 100644 --- a/opts/file.go +++ b/opts/file.go @@ -46,10 +46,10 @@ func parseKeyValueFile(filename string, emptyFn func(string) (string, bool)) ([] currentLine++ // line is not empty, and not starting with '#' if len(line) > 0 && !strings.HasPrefix(line, "#") { - data := strings.SplitN(line, "=", 2) + variable, value, hasValue := strings.Cut(line, "=") // trim the front of a variable, but nothing else - variable := strings.TrimLeft(data[0], whiteSpaces) + variable = strings.TrimLeft(variable, whiteSpaces) if strings.ContainsAny(variable, whiteSpaces) { return []string{}, ErrBadKey{fmt.Sprintf("variable '%s' contains whitespaces", variable)} } @@ -57,18 +57,17 @@ func parseKeyValueFile(filename string, emptyFn func(string) (string, bool)) ([] return []string{}, ErrBadKey{fmt.Sprintf("no variable name on line '%s'", line)} } - if len(data) > 1 { + if hasValue { // pass the value through, no trimming - lines = append(lines, fmt.Sprintf("%s=%s", variable, data[1])) + lines = append(lines, variable+"="+value) } else { - var value string var present bool if emptyFn != nil { value, present = emptyFn(line) } if present { // if only a pass-through variable is given, clean it up. - lines = append(lines, fmt.Sprintf("%s=%s", strings.TrimSpace(line), value)) + lines = append(lines, strings.TrimSpace(variable)+"="+value) } } } diff --git a/opts/gpus.go b/opts/gpus.go index 8796a805d4..93bf939786 100644 --- a/opts/gpus.go +++ b/opts/gpus.go @@ -38,14 +38,13 @@ func (o *GpuOpts) Set(value string) error { seen := map[string]struct{}{} // Set writable as the default for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := parts[0] + key, val, withValue := strings.Cut(field, "=") if _, ok := seen[key]; ok { return fmt.Errorf("gpu request key '%s' can be specified only once", key) } seen[key] = struct{}{} - if len(parts) == 1 { + if !withValue { seen["count"] = struct{}{} req.Count, err = parseCount(key) if err != nil { @@ -54,21 +53,20 @@ func (o *GpuOpts) Set(value string) error { continue } - value := parts[1] switch key { case "driver": - req.Driver = value + req.Driver = val case "count": - req.Count, err = parseCount(value) + req.Count, err = parseCount(val) if err != nil { return err } case "device": - req.DeviceIDs = strings.Split(value, ",") + req.DeviceIDs = strings.Split(val, ",") case "capabilities": - req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")} + req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")} case "options": - r := csv.NewReader(strings.NewReader(value)) + r := csv.NewReader(strings.NewReader(val)) optFields, err := r.Read() if err != nil { return errors.Wrap(err, "failed to read gpu options") diff --git a/opts/hosts.go b/opts/hosts.go index d59421b308..7cdd1218f7 100644 --- a/opts/hosts.go +++ b/opts/hosts.go @@ -33,6 +33,8 @@ const ( ) // ValidateHost validates that the specified string is a valid host and returns it. +// +// TODO(thaJeztah): ValidateHost appears to be unused; deprecate it. func ValidateHost(val string) (string, error) { host := strings.TrimSpace(val) // The empty string means default and is not handled by parseDockerDaemonHost @@ -69,18 +71,19 @@ func ParseHost(defaultToTLS bool, val string) (string, error) { // parseDockerDaemonHost parses the specified address and returns an address that will be used as the host. // Depending of the address specified, this may return one of the global Default* strings defined in hosts.go. func parseDockerDaemonHost(addr string) (string, error) { - addrParts := strings.SplitN(addr, "://", 2) - if len(addrParts) == 1 && addrParts[0] != "" { - addrParts = []string{"tcp", addrParts[0]} + proto, host, hasProto := strings.Cut(addr, "://") + if !hasProto && proto != "" { + host = proto + proto = "tcp" } - switch addrParts[0] { + switch proto { case "tcp": - return ParseTCPAddr(addrParts[1], defaultTCPHost) + return ParseTCPAddr(host, defaultTCPHost) case "unix": - return parseSimpleProtoAddr("unix", addrParts[1], defaultUnixSocket) + return parseSimpleProtoAddr(proto, host, defaultUnixSocket) case "npipe": - return parseSimpleProtoAddr("npipe", addrParts[1], defaultNamedPipe) + return parseSimpleProtoAddr(proto, host, defaultNamedPipe) case "fd": return addr, nil case "ssh": @@ -160,16 +163,18 @@ func ParseTCPAddr(tryAddr string, defaultAddr string) (string, error) { // ValidateExtraHost validates that the specified string is a valid extrahost and returns it. // ExtraHost is in the form of name:ip where the ip has to be a valid ip (IPv4 or IPv6). +// +// TODO(thaJeztah): remove client-side validation, and delegate to the API server. func ValidateExtraHost(val string) (string, error) { // allow for IPv6 addresses in extra hosts by only splitting on first ":" - arr := strings.SplitN(val, ":", 2) - if len(arr) != 2 || len(arr[0]) == 0 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return "", fmt.Errorf("bad format for add-host: %q", val) } // Skip IPaddr validation for "host-gateway" string - if arr[1] != hostGatewayName { - if _, err := ValidateIPAddress(arr[1]); err != nil { - return "", fmt.Errorf("invalid IP address in add-host: %q", arr[1]) + if v != hostGatewayName { + if _, err := ValidateIPAddress(v); err != nil { + return "", fmt.Errorf("invalid IP address in add-host: %q", v) } } return val, nil diff --git a/opts/mount.go b/opts/mount.go index 7ffc3acc93..2b531127eb 100644 --- a/opts/mount.go +++ b/opts/mount.go @@ -56,21 +56,21 @@ func (m *MountOpt) Set(value string) error { } setValueOnMap := func(target map[string]string, value string) { - parts := strings.SplitN(value, "=", 2) - if len(parts) == 1 { - target[value] = "" - } else { - target[parts[0]] = parts[1] + k, v, _ := strings.Cut(value, "=") + if k != "" { + target[k] = v } } mount.Type = mounttypes.TypeVolume // default to volume mounts // Set writable as the default for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := strings.ToLower(parts[0]) + key, val, ok := strings.Cut(field, "=") - if len(parts) == 1 { + // TODO(thaJeztah): these options should not be case-insensitive. + key = strings.ToLower(key) + + if !ok { switch key { case "readonly", "ro": mount.ReadOnly = true @@ -81,64 +81,61 @@ func (m *MountOpt) Set(value string) error { case "bind-nonrecursive": bindOptions().NonRecursive = true continue + default: + return fmt.Errorf("invalid field '%s' must be a key=value pair", field) } } - if len(parts) != 2 { - return fmt.Errorf("invalid field '%s' must be a key=value pair", field) - } - - value := parts[1] switch key { case "type": - mount.Type = mounttypes.Type(strings.ToLower(value)) + mount.Type = mounttypes.Type(strings.ToLower(val)) case "source", "src": - mount.Source = value - if strings.HasPrefix(value, "."+string(filepath.Separator)) || value == "." { - if abs, err := filepath.Abs(value); err == nil { + mount.Source = val + if strings.HasPrefix(val, "."+string(filepath.Separator)) || val == "." { + if abs, err := filepath.Abs(val); err == nil { mount.Source = abs } } case "target", "dst", "destination": - mount.Target = value + mount.Target = val case "readonly", "ro": - mount.ReadOnly, err = strconv.ParseBool(value) + mount.ReadOnly, err = strconv.ParseBool(val) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } case "consistency": - mount.Consistency = mounttypes.Consistency(strings.ToLower(value)) + mount.Consistency = mounttypes.Consistency(strings.ToLower(val)) case "bind-propagation": - bindOptions().Propagation = mounttypes.Propagation(strings.ToLower(value)) + bindOptions().Propagation = mounttypes.Propagation(strings.ToLower(val)) case "bind-nonrecursive": - bindOptions().NonRecursive, err = strconv.ParseBool(value) + bindOptions().NonRecursive, err = strconv.ParseBool(val) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } case "volume-nocopy": - volumeOptions().NoCopy, err = strconv.ParseBool(value) + volumeOptions().NoCopy, err = strconv.ParseBool(val) if err != nil { - return fmt.Errorf("invalid value for volume-nocopy: %s", value) + return fmt.Errorf("invalid value for volume-nocopy: %s", val) } case "volume-label": - setValueOnMap(volumeOptions().Labels, value) + setValueOnMap(volumeOptions().Labels, val) case "volume-driver": - volumeOptions().DriverConfig.Name = value + volumeOptions().DriverConfig.Name = val case "volume-opt": if volumeOptions().DriverConfig.Options == nil { volumeOptions().DriverConfig.Options = make(map[string]string) } - setValueOnMap(volumeOptions().DriverConfig.Options, value) + setValueOnMap(volumeOptions().DriverConfig.Options, val) case "tmpfs-size": - sizeBytes, err := units.RAMInBytes(value) + sizeBytes, err := units.RAMInBytes(val) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } tmpfsOptions().SizeBytes = sizeBytes case "tmpfs-mode": - ui64, err := strconv.ParseUint(value, 8, 32) + ui64, err := strconv.ParseUint(val, 8, 32) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } tmpfsOptions().Mode = os.FileMode(ui64) default: diff --git a/opts/network.go b/opts/network.go index ce7370ee0e..12c3977b1b 100644 --- a/opts/network.go +++ b/opts/network.go @@ -48,34 +48,33 @@ func (n *NetworkOpt) Set(value string) error { netOpt.Aliases = []string{} for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - - if len(parts) < 2 { + // TODO(thaJeztah): these options should not be case-insensitive. + key, val, ok := strings.Cut(strings.ToLower(field), "=") + if !ok || key == "" { return fmt.Errorf("invalid field %s", field) } - key := strings.TrimSpace(strings.ToLower(parts[0])) - value := strings.TrimSpace(strings.ToLower(parts[1])) + key = strings.TrimSpace(key) + val = strings.TrimSpace(val) switch key { case networkOptName: - netOpt.Target = value + netOpt.Target = val case networkOptAlias: - netOpt.Aliases = append(netOpt.Aliases, value) + netOpt.Aliases = append(netOpt.Aliases, val) case networkOptIPv4Address: - netOpt.IPv4Address = value + netOpt.IPv4Address = val case networkOptIPv6Address: - netOpt.IPv6Address = value + netOpt.IPv6Address = val case driverOpt: - key, value, err = parseDriverOpt(value) - if err == nil { - if netOpt.DriverOpts == nil { - netOpt.DriverOpts = make(map[string]string) - } - netOpt.DriverOpts[key] = value - } else { + key, val, err = parseDriverOpt(val) + if err != nil { return err } + if netOpt.DriverOpts == nil { + netOpt.DriverOpts = make(map[string]string) + } + netOpt.DriverOpts[key] = val default: return fmt.Errorf("invalid field key %s", key) } @@ -116,11 +115,13 @@ func (n *NetworkOpt) NetworkMode() string { } func parseDriverOpt(driverOpt string) (string, string, error) { - parts := strings.SplitN(driverOpt, "=", 2) - if len(parts) != 2 { + // TODO(thaJeztah): these options should not be case-insensitive. + // TODO(thaJeztah): should value be converted to lowercase as well, or only the key? + key, value, ok := strings.Cut(strings.ToLower(driverOpt), "=") + if !ok || key == "" { return "", "", fmt.Errorf("invalid key value pair format in driver options") } - key := strings.TrimSpace(strings.ToLower(parts[0])) - value := strings.TrimSpace(strings.ToLower(parts[1])) + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) return key, value, nil } diff --git a/opts/opts.go b/opts/opts.go index 7cb87b72dd..4e7790f0fc 100644 --- a/opts/opts.go +++ b/opts/opts.go @@ -165,12 +165,8 @@ func (opts *MapOpts) Set(value string) error { } value = v } - vals := strings.SplitN(value, "=", 2) - if len(vals) == 1 { - opts.values[vals[0]] = "" - } else { - opts.values[vals[0]] = vals[1] - } + k, v, _ := strings.Cut(value, "=") + opts.values[k] = v return nil } @@ -277,16 +273,16 @@ func validateDomain(val string) (string, error) { // // TODO discuss if quotes (and other special characters) should be valid or invalid for keys // TODO discuss if leading/trailing whitespace in keys should be preserved (and valid) -func ValidateLabel(val string) (string, error) { - arr := strings.SplitN(val, "=", 2) - key := strings.TrimLeft(arr[0], whiteSpaces) +func ValidateLabel(value string) (string, error) { + key, _, _ := strings.Cut(value, "=") + key = strings.TrimLeft(key, whiteSpaces) if key == "" { - return "", fmt.Errorf("invalid label '%s': empty name", val) + return "", fmt.Errorf("invalid label '%s': empty name", value) } if strings.ContainsAny(key, whiteSpaces) { return "", fmt.Errorf("label '%s' contains whitespaces", key) } - return val, nil + return value, nil } // ValidateSysctl validates a sysctl and returns it. @@ -305,20 +301,19 @@ func ValidateSysctl(val string) (string, error) { "net.", "fs.mqueue.", } - arr := strings.Split(val, "=") - if len(arr) < 2 { - return "", fmt.Errorf("sysctl '%s' is not whitelisted", val) + k, _, ok := strings.Cut(val, "=") + if !ok || k == "" { + return "", fmt.Errorf("sysctl '%s' is not allowed", val) } - if validSysctlMap[arr[0]] { + if validSysctlMap[k] { return val, nil } - for _, vp := range validSysctlPrefixes { - if strings.HasPrefix(arr[0], vp) { + if strings.HasPrefix(k, vp) { return val, nil } } - return "", fmt.Errorf("sysctl '%s' is not whitelisted", val) + return "", fmt.Errorf("sysctl '%s' is not allowed", val) } // FilterOpt is a flag type for validating filters @@ -347,11 +342,12 @@ func (o *FilterOpt) Set(value string) error { if !strings.Contains(value, "=") { return errors.New("bad format of filter (expected name=value)") } - f := strings.SplitN(value, "=", 2) - name := strings.ToLower(strings.TrimSpace(f[0])) - value = strings.TrimSpace(f[1]) + name, val, _ := strings.Cut(value, "=") - o.filter.Add(name, value) + // TODO(thaJeztah): these options should not be case-insensitive. + name = strings.ToLower(strings.TrimSpace(name)) + val = strings.TrimSpace(val) + o.filter.Add(name, val) return nil } @@ -411,10 +407,14 @@ func ParseLink(val string) (string, string, error) { if val == "" { return "", "", fmt.Errorf("empty string specified for links") } - arr := strings.Split(val, ":") + // We expect two parts, but restrict to three to allow detecting invalid formats. + arr := strings.SplitN(val, ":", 3) + + // TODO(thaJeztah): clean up this logic!! if len(arr) > 2 { return "", "", fmt.Errorf("bad format for links: %s", val) } + // TODO(thaJeztah): this should trim the "/" prefix as well?? if len(arr) == 1 { return val, val, nil } @@ -422,6 +422,7 @@ func ParseLink(val string) (string, string, error) { // from an already created container and the format is not `foo:bar` // but `/foo:/c1/bar` if strings.HasPrefix(arr[0], "/") { + // TODO(thaJeztah): clean up this logic!! _, alias := path.Split(arr[1]) return arr[0][1:], alias, nil } diff --git a/opts/opts_test.go b/opts/opts_test.go index f8b6267165..1c91df5a4e 100644 --- a/opts/opts_test.go +++ b/opts/opts_test.go @@ -301,12 +301,12 @@ func TestValidateLabel(t *testing.T) { } func sampleValidator(val string) (string, error) { - allowedKeys := map[string]string{"max-size": "1", "max-file": "2"} - vals := strings.Split(val, "=") - if allowedKeys[vals[0]] != "" { + allowedKeys := map[string]string{"valid-option": "1", "valid-option2": "2"} + k, _, _ := strings.Cut(val, "=") + if allowedKeys[k] != "" { return val, nil } - return "", fmt.Errorf("invalid key %s", vals[0]) + return "", fmt.Errorf("invalid key %s", k) } func TestNamedListOpts(t *testing.T) { diff --git a/opts/parse.go b/opts/parse.go index 4012c461fb..017577e4bf 100644 --- a/opts/parse.go +++ b/opts/parse.go @@ -41,12 +41,8 @@ func readKVStrings(files []string, override []string, emptyFn func(string) (stri func ConvertKVStringsToMap(values []string) map[string]string { result := make(map[string]string, len(values)) for _, value := range values { - kv := strings.SplitN(value, "=", 2) - if len(kv) == 1 { - result[kv[0]] = "" - } else { - result[kv[0]] = kv[1] - } + k, v, _ := strings.Cut(value, "=") + result[k] = v } return result @@ -62,11 +58,11 @@ func ConvertKVStringsToMap(values []string) map[string]string { func ConvertKVStringsToMapWithNil(values []string) map[string]*string { result := make(map[string]*string, len(values)) for _, value := range values { - kv := strings.SplitN(value, "=", 2) - if len(kv) == 1 { - result[kv[0]] = nil + k, v, ok := strings.Cut(value, "=") + if !ok { + result[k] = nil } else { - result[kv[0]] = &kv[1] + result[k] = &v } } @@ -81,21 +77,15 @@ func ParseRestartPolicy(policy string) (container.RestartPolicy, error) { return p, nil } - parts := strings.Split(policy, ":") - - if len(parts) > 2 { - return p, fmt.Errorf("invalid restart policy format") - } - if len(parts) == 2 { - count, err := strconv.Atoi(parts[1]) + k, v, _ := strings.Cut(policy, ":") + if v != "" { + count, err := strconv.Atoi(v) if err != nil { - return p, fmt.Errorf("maximum retry count must be an integer") + return p, fmt.Errorf("invalid restart policy format: maximum retry count must be an integer") } - p.MaximumRetryCount = count } - p.Name = parts[0] - + p.Name = k return p, nil } diff --git a/opts/port.go b/opts/port.go index c0814dbd9a..fe41cdd288 100644 --- a/opts/port.go +++ b/opts/port.go @@ -42,36 +42,33 @@ func (p *PortOpt) Set(value string) error { pConfig := swarm.PortConfig{} for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - if len(parts) != 2 { + // TODO(thaJeztah): these options should not be case-insensitive. + key, val, ok := strings.Cut(strings.ToLower(field), "=") + if !ok || key == "" { return fmt.Errorf("invalid field %s", field) } - - key := strings.ToLower(parts[0]) - value := strings.ToLower(parts[1]) - switch key { case portOptProtocol: - if value != string(swarm.PortConfigProtocolTCP) && value != string(swarm.PortConfigProtocolUDP) && value != string(swarm.PortConfigProtocolSCTP) { - return fmt.Errorf("invalid protocol value %s", value) + if val != string(swarm.PortConfigProtocolTCP) && val != string(swarm.PortConfigProtocolUDP) && val != string(swarm.PortConfigProtocolSCTP) { + return fmt.Errorf("invalid protocol value %s", val) } - pConfig.Protocol = swarm.PortConfigProtocol(value) + pConfig.Protocol = swarm.PortConfigProtocol(val) case portOptMode: - if value != string(swarm.PortConfigPublishModeIngress) && value != string(swarm.PortConfigPublishModeHost) { - return fmt.Errorf("invalid publish mode value %s", value) + if val != string(swarm.PortConfigPublishModeIngress) && val != string(swarm.PortConfigPublishModeHost) { + return fmt.Errorf("invalid publish mode value %s", val) } - pConfig.PublishMode = swarm.PortConfigPublishMode(value) + pConfig.PublishMode = swarm.PortConfigPublishMode(val) case portOptTargetPort: - tPort, err := strconv.ParseUint(value, 10, 16) + tPort, err := strconv.ParseUint(val, 10, 16) if err != nil { return err } pConfig.TargetPort = uint32(tPort) case portOptPublishedPort: - pPort, err := strconv.ParseUint(value, 10, 16) + pPort, err := strconv.ParseUint(val, 10, 16) if err != nil { return err } diff --git a/opts/secret.go b/opts/secret.go index fabc62c01a..750dbe4f30 100644 --- a/opts/secret.go +++ b/opts/secret.go @@ -40,25 +40,22 @@ func (o *SecretOpt) Set(value string) error { } for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := strings.ToLower(parts[0]) - - if len(parts) != 2 { + key, val, ok := strings.Cut(field, "=") + if !ok || key == "" { return fmt.Errorf("invalid field '%s' must be a key=value pair", field) } - - value := parts[1] - switch key { + // TODO(thaJeztah): these options should not be case-insensitive. + switch strings.ToLower(key) { case "source", "src": - options.SecretName = value + options.SecretName = val case "target": - options.File.Name = value + options.File.Name = val case "uid": - options.File.UID = value + options.File.UID = val case "gid": - options.File.GID = value + options.File.GID = val case "mode": - m, err := strconv.ParseUint(value, 0, 32) + m, err := strconv.ParseUint(val, 0, 32) if err != nil { return fmt.Errorf("invalid mode specified: %v", err) } diff --git a/opts/throttledevice.go b/opts/throttledevice.go index 773092bc0b..9fb788433b 100644 --- a/opts/throttledevice.go +++ b/opts/throttledevice.go @@ -14,14 +14,15 @@ type ValidatorThrottleFctType func(val string) (*blkiodev.ThrottleDevice, error) // ValidateThrottleBpsDevice validates that the specified string has a valid device-rate format. func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) { - split := strings.SplitN(val, ":", 2) - if len(split) != 2 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return nil, fmt.Errorf("bad format: %s", val) } - if !strings.HasPrefix(split[0], "/dev/") { + // TODO(thaJeztah): should we really validate this on the client? + if !strings.HasPrefix(k, "/dev/") { return nil, fmt.Errorf("bad format for device path: %s", val) } - rate, err := units.RAMInBytes(split[1]) + rate, err := units.RAMInBytes(v) if err != nil { return nil, fmt.Errorf("invalid rate for device: %s. The correct format is :[]. Number must be a positive integer. Unit is optional and can be kb, mb, or gb", val) } @@ -30,26 +31,27 @@ func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) { } return &blkiodev.ThrottleDevice{ - Path: split[0], + Path: v, Rate: uint64(rate), }, nil } // ValidateThrottleIOpsDevice validates that the specified string has a valid device-rate format. func ValidateThrottleIOpsDevice(val string) (*blkiodev.ThrottleDevice, error) { - split := strings.SplitN(val, ":", 2) - if len(split) != 2 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return nil, fmt.Errorf("bad format: %s", val) } - if !strings.HasPrefix(split[0], "/dev/") { + // TODO(thaJeztah): should we really validate this on the client? + if !strings.HasPrefix(k, "/dev/") { return nil, fmt.Errorf("bad format for device path: %s", val) } - rate, err := strconv.ParseUint(split[1], 10, 64) + rate, err := strconv.ParseUint(v, 10, 64) if err != nil { return nil, fmt.Errorf("invalid rate for device: %s. The correct format is :. Number must be a positive integer", val) } - return &blkiodev.ThrottleDevice{Path: split[0], Rate: rate}, nil + return &blkiodev.ThrottleDevice{Path: k, Rate: rate}, nil } // ThrottledeviceOpt defines a map of ThrottleDevices diff --git a/opts/weightdevice.go b/opts/weightdevice.go index 7d217ceef9..3077e3da7f 100644 --- a/opts/weightdevice.go +++ b/opts/weightdevice.go @@ -13,14 +13,15 @@ type ValidatorWeightFctType func(val string) (*blkiodev.WeightDevice, error) // ValidateWeightDevice validates that the specified string has a valid device-weight format. func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) { - split := strings.SplitN(val, ":", 2) - if len(split) != 2 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return nil, fmt.Errorf("bad format: %s", val) } - if !strings.HasPrefix(split[0], "/dev/") { + // TODO(thaJeztah): should we really validate this on the client? + if !strings.HasPrefix(k, "/dev/") { return nil, fmt.Errorf("bad format for device path: %s", val) } - weight, err := strconv.ParseUint(split[1], 10, 16) + weight, err := strconv.ParseUint(v, 10, 16) if err != nil { return nil, fmt.Errorf("invalid weight for device: %s", val) } @@ -29,7 +30,7 @@ func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) { } return &blkiodev.WeightDevice{ - Path: split[0], + Path: k, Weight: uint16(weight), }, nil } @@ -42,9 +43,8 @@ type WeightdeviceOpt struct { // NewWeightdeviceOpt creates a new WeightdeviceOpt func NewWeightdeviceOpt(validator ValidatorWeightFctType) WeightdeviceOpt { - values := []*blkiodev.WeightDevice{} return WeightdeviceOpt{ - values: values, + values: []*blkiodev.WeightDevice{}, validator: validator, } }