diff --git a/cli/command/swarm/ca.go b/cli/command/swarm/ca.go index 36b1dc9128..a648d3e833 100644 --- a/cli/command/swarm/ca.go +++ b/cli/command/swarm/ca.go @@ -96,7 +96,7 @@ func runCA(ctx context.Context, dockerCli command.Cli, flags *pflag.FlagSet, opt func updateSwarmSpec(spec *swarm.Spec, flags *pflag.FlagSet, opts caOptions) { caCert := opts.rootCACert.Contents() caKey := opts.rootCAKey.Contents() - opts.mergeSwarmSpecCAFlags(spec, flags, caCert) + opts.mergeSwarmSpecCAFlags(spec, flags, &caCert) spec.CAConfig.SigningCACert = caCert spec.CAConfig.SigningCAKey = caKey diff --git a/cli/command/swarm/client_test.go b/cli/command/swarm/client_test.go index c712887018..99f2dcfe7f 100644 --- a/cli/command/swarm/client_test.go +++ b/cli/command/swarm/client_test.go @@ -11,7 +11,7 @@ import ( type fakeClient struct { client.Client infoFunc func() (system.Info, error) - swarmInitFunc func() (string, error) + swarmInitFunc func(req swarm.InitRequest) (string, error) swarmInspectFunc func() (swarm.Swarm, error) nodeInspectFunc func() (swarm.Node, []byte, error) swarmGetUnlockKeyFunc func() (swarm.UnlockKeyResponse, error) @@ -35,9 +35,9 @@ func (cli *fakeClient) NodeInspectWithRaw(context.Context, string) (swarm.Node, return swarm.Node{}, []byte{}, nil } -func (cli *fakeClient) SwarmInit(context.Context, swarm.InitRequest) (string, error) { +func (cli *fakeClient) SwarmInit(_ context.Context, req swarm.InitRequest) (string, error) { if cli.swarmInitFunc != nil { - return cli.swarmInitFunc() + return cli.swarmInitFunc(req) } return "", nil } diff --git a/cli/command/swarm/init_test.go b/cli/command/swarm/init_test.go index 1ed5cfb9b8..562751bfb6 100644 --- a/cli/command/swarm/init_test.go +++ b/cli/command/swarm/init_test.go @@ -4,11 +4,14 @@ import ( "errors" "fmt" "io" + "os" + "path/filepath" "testing" "github.com/docker/cli/internal/test" "github.com/docker/docker/api/types/swarm" "gotest.tools/v3/assert" + is "gotest.tools/v3/assert/cmp" "gotest.tools/v3/golden" ) @@ -16,7 +19,7 @@ func TestSwarmInitErrorOnAPIFailure(t *testing.T) { testCases := []struct { name string flags map[string]string - swarmInitFunc func() (string, error) + swarmInitFunc func(swarm.InitRequest) (string, error) swarmInspectFunc func() (swarm.Swarm, error) swarmGetUnlockKeyFunc func() (swarm.UnlockKeyResponse, error) nodeInspectFunc func() (swarm.Node, []byte, error) @@ -24,14 +27,14 @@ func TestSwarmInitErrorOnAPIFailure(t *testing.T) { }{ { name: "init-failed", - swarmInitFunc: func() (string, error) { + swarmInitFunc: func(swarm.InitRequest) (string, error) { return "", errors.New("error initializing the swarm") }, expectedError: "error initializing the swarm", }, { name: "init-failed-with-ip-choice", - swarmInitFunc: func() (string, error) { + swarmInitFunc: func(swarm.InitRequest) (string, error) { return "", errors.New("could not choose an IP address to advertise") }, expectedError: "could not choose an IP address to advertise - specify one with --advertise-addr", @@ -85,14 +88,14 @@ func TestSwarmInit(t *testing.T) { testCases := []struct { name string flags map[string]string - swarmInitFunc func() (string, error) + swarmInitFunc func(req swarm.InitRequest) (string, error) swarmInspectFunc func() (swarm.Swarm, error) swarmGetUnlockKeyFunc func() (swarm.UnlockKeyResponse, error) nodeInspectFunc func() (swarm.Node, []byte, error) }{ { name: "init", - swarmInitFunc: func() (string, error) { + swarmInitFunc: func(swarm.InitRequest) (string, error) { return "nodeID", nil }, }, @@ -101,7 +104,7 @@ func TestSwarmInit(t *testing.T) { flags: map[string]string{ flagAutolock: "true", }, - swarmInitFunc: func() (string, error) { + swarmInitFunc: func(swarm.InitRequest) (string, error) { return "nodeID", nil }, swarmGetUnlockKeyFunc: func() (swarm.UnlockKeyResponse, error) { @@ -131,3 +134,28 @@ func TestSwarmInit(t *testing.T) { }) } } + +func TestSwarmInitWithExternalCA(t *testing.T) { + cli := test.NewFakeCli(&fakeClient{ + swarmInitFunc: func(req swarm.InitRequest) (string, error) { + if assert.Check(t, is.Len(req.Spec.CAConfig.ExternalCAs, 1)) { + assert.Equal(t, req.Spec.CAConfig.ExternalCAs[0].CACert, cert) + assert.Equal(t, req.Spec.CAConfig.ExternalCAs[0].Protocol, swarm.ExternalCAProtocolCFSSL) + assert.Equal(t, req.Spec.CAConfig.ExternalCAs[0].URL, "https://example.com") + } + return "nodeID", nil + }, + }) + + tempDir := t.TempDir() + certFile := filepath.Join(tempDir, "cert.pem") + err := os.WriteFile(certFile, []byte(cert), 0o644) + assert.NilError(t, err) + + cmd := newInitCommand(cli) + cmd.SetArgs([]string{}) + cmd.SetOut(io.Discard) + cmd.SetErr(io.Discard) + assert.NilError(t, cmd.Flags().Set(flagExternalCA, "protocol=cfssl,url=https://example.com,cacert="+certFile)) + assert.NilError(t, cmd.Execute()) +} diff --git a/cli/command/swarm/opts.go b/cli/command/swarm/opts.go index a29a67147c..41a0244ed4 100644 --- a/cli/command/swarm/opts.go +++ b/cli/command/swarm/opts.go @@ -231,7 +231,7 @@ func addSwarmFlags(flags *pflag.FlagSet, options *swarmOptions) { addSwarmCAFlags(flags, &options.swarmCAOptions) } -func (o *swarmOptions) mergeSwarmSpec(spec *swarm.Spec, flags *pflag.FlagSet, caCert string) { +func (o *swarmOptions) mergeSwarmSpec(spec *swarm.Spec, flags *pflag.FlagSet, caCert *string) { if flags.Changed(flagTaskHistoryLimit) { spec.Orchestration.TaskHistoryRetentionLimit = &o.taskHistoryLimit } @@ -255,20 +255,24 @@ type swarmCAOptions struct { externalCA ExternalCAOption } -func (o *swarmCAOptions) mergeSwarmSpecCAFlags(spec *swarm.Spec, flags *pflag.FlagSet, caCert string) { +func (o *swarmCAOptions) mergeSwarmSpecCAFlags(spec *swarm.Spec, flags *pflag.FlagSet, caCert *string) { if flags.Changed(flagCertExpiry) { spec.CAConfig.NodeCertExpiry = o.nodeCertExpiry } if flags.Changed(flagExternalCA) { spec.CAConfig.ExternalCAs = o.externalCA.Value() - for _, ca := range spec.CAConfig.ExternalCAs { - ca.CACert = caCert + if caCert != nil { + for _, ca := range spec.CAConfig.ExternalCAs { + if ca.CACert == "" { + ca.CACert = *caCert + } + } } } } func (o *swarmOptions) ToSpec(flags *pflag.FlagSet) swarm.Spec { var spec swarm.Spec - o.mergeSwarmSpec(&spec, flags, "") + o.mergeSwarmSpec(&spec, flags, nil) return spec } diff --git a/cli/command/swarm/update.go b/cli/command/swarm/update.go index 50ddc17b1a..2e853a9312 100644 --- a/cli/command/swarm/update.go +++ b/cli/command/swarm/update.go @@ -53,7 +53,7 @@ func runUpdate(ctx context.Context, dockerCli command.Cli, flags *pflag.FlagSet, prevAutoLock := swarmInspect.Spec.EncryptionConfig.AutoLockManagers - opts.mergeSwarmSpec(&swarmInspect.Spec, flags, swarmInspect.ClusterInfo.TLSInfo.TrustRoot) + opts.mergeSwarmSpec(&swarmInspect.Spec, flags, &swarmInspect.ClusterInfo.TLSInfo.TrustRoot) curAutoLock := swarmInspect.Spec.EncryptionConfig.AutoLockManagers