From ecd619d55a0c125ffdace0b439e40feb4a2b6b76 Mon Sep 17 00:00:00 2001 From: Jimmi Dyson Date: Wed, 6 Sep 2023 16:59:32 +0100 Subject: [PATCH] feat: Update variable getter to handle nested fields --- .../clustertopology/variables/variable.go | 24 ++++++++- .../variables/variables_test.go | 54 +++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/pkg/capi/clustertopology/variables/variable.go b/pkg/capi/clustertopology/variables/variable.go index 632f155cf..a7f053169 100644 --- a/pkg/capi/clustertopology/variables/variable.go +++ b/pkg/capi/clustertopology/variables/variable.go @@ -7,6 +7,7 @@ import ( "encoding/json" apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "sigs.k8s.io/cluster-api/exp/runtime/topologymutation" ) @@ -14,12 +15,33 @@ import ( func Get[T any]( variables map[string]apiextensionsv1.JSON, name string, + fields ...string, ) (value T, found bool, err error) { variable, found, err := topologymutation.GetVariable(variables, name) if err != nil || !found { return value, found, err } - err = json.Unmarshal(variable.Raw, &value) + jsonValue := variable.Raw + + if len(fields) > 0 { + var unstr map[string]interface{} + err = json.Unmarshal(jsonValue, &unstr) + if err != nil { + return value, found, err + } + + nestedField, found, err := unstructured.NestedFieldCopy(unstr, fields...) + if err != nil || !found { + return value, found, err + } + + jsonValue, err = json.Marshal(nestedField) + if err != nil { + return value, found, err + } + } + + err = json.Unmarshal(jsonValue, &value) return value, err == nil, err } diff --git a/pkg/capi/clustertopology/variables/variables_test.go b/pkg/capi/clustertopology/variables/variables_test.go index b26a48bea..073e06d69 100644 --- a/pkg/capi/clustertopology/variables/variables_test.go +++ b/pkg/capi/clustertopology/variables/variables_test.go @@ -51,3 +51,57 @@ func TestGetVariable_ParseError(t *testing.T) { g.Expect(found).To(BeFalse()) g.Expect(parsed).To(BeEmpty()) } + +func TestGet_ValidNestedFieldAsStruct(t *testing.T) { + g := NewWithT(t) + + type nestedStruct struct { + Bar string `json:"bar"` + } + sampleValue := []byte(`{"foo": {"bar": "baz"}}`) + vars := map[string]apiextensionsv1.JSON{ + "sampleVar": {Raw: sampleValue}, + } + parsed, found, err := variables.Get[nestedStruct](vars, "sampleVar", "foo") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(found).To(BeTrue()) + g.Expect(parsed).To(Equal(nestedStruct{ + Bar: "baz", + })) +} + +func TestGet_ValidNestedFieldAsScalar(t *testing.T) { + g := NewWithT(t) + + sampleValue := []byte(`{"foo": {"bar": "baz"}}`) + vars := map[string]apiextensionsv1.JSON{ + "sampleVar": {Raw: sampleValue}, + } + parsed, found, err := variables.Get[string](vars, "sampleVar", "foo", "bar") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(found).To(BeTrue()) + g.Expect(parsed).To(Equal("baz")) +} + +func TestGet_InvalidNestedFieldType(t *testing.T) { + g := NewWithT(t) + + sampleValue := []byte(`{"foo": {"bar": "baz"}}`) + vars := map[string]apiextensionsv1.JSON{ + "sampleVar": {Raw: sampleValue}, + } + _, _, err := variables.Get[int](vars, "sampleVar", "foo", "bar") + g.Expect(err).To(HaveOccurred()) +} + +func TestGet_MissingNestedField(t *testing.T) { + g := NewWithT(t) + + sampleValue := []byte(`{"foo": {"bar": "baz"}}`) + vars := map[string]apiextensionsv1.JSON{ + "sampleVar": {Raw: sampleValue}, + } + _, found, err := variables.Get[string](vars, "sampleVar", "foo", "nonexistent") + g.Expect(err).NotTo(HaveOccurred()) + g.Expect(found).To(BeFalse()) +}