Skip to content

Commit 2088bc3

Browse files
committed
api: Validate provided server group IDs
Don't blindly accept the requested server group ID. Signed-off-by: Stephen Finucane <[email protected]>
1 parent 4f2b197 commit 2088bc3

File tree

4 files changed

+82
-6
lines changed

4 files changed

+82
-6
lines changed

pkg/clients/compute.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ type ComputeClient interface {
6363
DeleteAttachedInterface(serverID, portID string) error
6464

6565
ListServerGroups() ([]servergroups.ServerGroup, error)
66+
GetServerGroup(serverGroupID string) (*servergroups.ServerGroup, error)
6667
}
6768

6869
type computeClient struct{ client *gophercloud.ServiceClient }
@@ -162,6 +163,16 @@ func (c computeClient) ListServerGroups() ([]servergroups.ServerGroup, error) {
162163
return servergroups.ExtractServerGroups(allPages)
163164
}
164165

166+
func (c computeClient) GetServerGroup(serverGroupID string) (*servergroups.ServerGroup, error) {
167+
var serverGroup servergroups.ServerGroup
168+
mc := metrics.NewMetricPrometheusContext("server_group", "get")
169+
err := servergroups.Get(c.client, serverGroupID).ExtractInto(&serverGroup)
170+
if mc.ObserveRequestIgnoreNotFound(err) != nil {
171+
return nil, err
172+
}
173+
return &serverGroup, nil
174+
}
175+
165176
type computeErrorClient struct{ error }
166177

167178
// NewComputeErrorClient returns a ComputeClient in which every method returns the given error.
@@ -204,3 +215,7 @@ func (e computeErrorClient) DeleteAttachedInterface(_, _ string) error {
204215
func (e computeErrorClient) ListServerGroups() ([]servergroups.ServerGroup, error) {
205216
return nil, e.error
206217
}
218+
219+
func (e computeErrorClient) GetServerGroup(_ string) (*servergroups.ServerGroup, error) {
220+
return nil, e.error
221+
}

pkg/clients/mock/compute.go

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/cloud/services/compute/instance.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,11 +604,27 @@ func (s *Service) getImageID(imageUUID, imageName string) (string, error) {
604604

605605
func (s *Service) getServerGroupID(serverGroupID string, serverGroupFilter *infrav1.ServerGroupFilter) (string, error) {
606606
if serverGroupFilter == nil {
607-
return serverGroupID, nil
607+
if serverGroupID == "" {
608+
return "", nil
609+
}
610+
611+
// only fallback to the legacy value if a filter wasn't provided
612+
serverGroupFilter = &infrav1.ServerGroupFilter{
613+
ID: serverGroupID,
614+
}
608615
}
609616

617+
// if we have an ID, use that
610618
if serverGroupFilter.ID != "" {
611-
return serverGroupFilter.ID, nil
619+
serverGroup, err := s.getComputeClient().GetServerGroup(serverGroupFilter.ID)
620+
if err != nil {
621+
return "", err
622+
}
623+
if serverGroupFilter.Name != "" && serverGroupFilter.Name != serverGroup.Name {
624+
// this is super dumb and no sensible person will do this, but we should do the correct thing
625+
return "", fmt.Errorf("found server with ID %s but name %s does not match", serverGroupFilter.ID, serverGroupFilter.Name)
626+
}
627+
return serverGroup.ID, nil
612628
}
613629

614630
// otherwise fallback to looking up by name, which is slower

pkg/cloud/services/compute/instance_test.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,19 @@ func TestService_getServerGroupID(t *testing.T) {
136136
serverGroupID: serverGroupID1,
137137
want: serverGroupID1,
138138
expect: func(m *mock.MockComputeClientMockRecorder) {
139+
m.GetServerGroup(serverGroupID1).Return(
140+
&servergroups.ServerGroup{ID: serverGroupID1, Name: "test-server-group"},
141+
nil)
139142
},
140143
wantErr: false,
141144
},
142145
{
143146
testName: "Return server group ID from filter if only filter (with ID) given",
144147
serverGroupFilter: &infrav1.ServerGroupFilter{ID: serverGroupID1},
145148
expect: func(m *mock.MockComputeClientMockRecorder) {
149+
m.GetServerGroup(serverGroupID1).Return(
150+
&servergroups.ServerGroup{ID: serverGroupID1, Name: "test-server-group"},
151+
nil)
146152
},
147153
want: serverGroupID1,
148154
wantErr: false,
@@ -382,7 +388,6 @@ func getDefaultInstanceSpec() *InstanceSpec {
382388
},
383389
ConfigDrive: *pointer.Bool(true),
384390
FailureDomain: *pointer.String(failureDomain),
385-
ServerGroupID: serverGroupUUID,
386391
Tags: []string{"test-tag"},
387392
SecurityGroups: []infrav1.SecurityGroupFilter{{ID: workerSecurityGroupUUID}},
388393
}
@@ -421,9 +426,6 @@ func TestService_ReconcileInstance(t *testing.T) {
421426
},
422427
},
423428
},
424-
"os:scheduler_hints": map[string]interface{}{
425-
"group": serverGroupUUID,
426-
},
427429
}
428430
}
429431

@@ -497,6 +499,12 @@ func TestService_ReconcileInstance(t *testing.T) {
497499
computeRecorder.GetFlavorFromName(flavorName).Return(&f, nil)
498500
}
499501

502+
// Expected calls when polling for server creation
503+
expectGetServerGroup := func(computeRecorder *mock.MockComputeClientMockRecorder, serverGroupUUID string) {
504+
serverGroup := servergroups.ServerGroup{ID: serverGroupUUID}
505+
computeRecorder.GetServerGroup(serverGroupUUID).Return(&serverGroup, nil)
506+
}
507+
500508
// Expected calls and custom match function for creating a server
501509
expectCreateServer := func(computeRecorder *mock.MockComputeClientMockRecorder, expectedCreateOpts map[string]interface{}, wantError bool) {
502510
// This nonsense is because ConfigDrive is a bool pointer, so we
@@ -576,6 +584,28 @@ func TestService_ReconcileInstance(t *testing.T) {
576584
},
577585
wantErr: false,
578586
},
587+
{
588+
name: "Boot with server group",
589+
getInstanceSpec: func() *InstanceSpec {
590+
s := getDefaultInstanceSpec()
591+
s.ServerGroupID = serverGroupUUID
592+
return s
593+
},
594+
expect: func(r *recorders) {
595+
expectUseExistingDefaultPort(r.network)
596+
expectDefaultImageAndFlavor(r.compute, r.image)
597+
598+
createMap := getDefaultServerMap()
599+
createMap["os:scheduler_hints"] = map[string]interface{}{
600+
"group": serverGroupUUID,
601+
}
602+
603+
expectGetServerGroup(r.compute, serverGroupUUID)
604+
expectCreateServer(r.compute, createMap, false)
605+
expectServerPollSuccess(r.compute)
606+
},
607+
wantErr: false,
608+
},
579609
{
580610
name: "Delete ports on server create error",
581611
getInstanceSpec: getDefaultInstanceSpec,

0 commit comments

Comments
 (0)