Skip to content

Commit f7c8fed

Browse files
authored
fix: fix some issues about database generator (#742)
1 parent 0dfe9c5 commit f7c8fed

File tree

8 files changed

+46
-33
lines changed

8 files changed

+46
-33
lines changed

pkg/modules/generators/accessories/mysql/alicloud_rds.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package mysql
22

33
import (
4+
"fmt"
5+
"os"
46
"strings"
57

68
v1 "k8s.io/api/core/v1"
@@ -14,6 +16,7 @@ import (
1416

1517
const (
1618
defaultAlicloudProviderURL = "registry.terraform.io/aliyun/alicloud/1.209.1"
19+
alicloudRegionEnv = "ALICLOUD_REGION"
1720
alicloudDBInstance = "alicloud_db_instance"
1821
alicloudDBConnection = "alicloud_db_connection"
1922
alicloudRDSAccount = "alicloud_rds_account"
@@ -58,9 +61,12 @@ func (g *mysqlGenerator) generateAlicloudResources(db *mysql.MySQL, spec *apiv1.
5861
}
5962

6063
// Get the alicloud provider region, and the region of the alicloud provider must be set.
61-
alicloudProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider])
62-
if err != nil {
63-
return nil, err
64+
var alicloudProviderRegion string
65+
if alicloudProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]); alicloudProviderRegion == "" {
66+
alicloudProviderRegion = os.Getenv(alicloudRegionEnv)
67+
}
68+
if alicloudProviderRegion == "" {
69+
return nil, fmt.Errorf("alicloud provider region should not be empty")
6470
}
6571

6672
// Build alicloud_db_instance.

pkg/modules/generators/accessories/mysql/aws_rds.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mysql
22

33
import (
44
"fmt"
5+
"os"
56

67
v1 "k8s.io/api/core/v1"
78

@@ -14,6 +15,7 @@ import (
1415

1516
const (
1617
defaultAWSProviderURL = "registry.terraform.io/hashicorp/aws/5.0.1"
18+
awsRegionEnv = "AWS_REGION"
1719
awsSecurityGroup = "aws_security_group"
1820
awsDBInstance = "aws_db_instance"
1921
)
@@ -62,9 +64,12 @@ func (g *mysqlGenerator) generateAWSResources(db *mysql.MySQL, spec *apiv1.Inten
6264
}
6365

6466
// Get the aws provider region, and the region of the aws provider must be set.
65-
awsProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider])
66-
if err != nil {
67-
return nil, err
67+
var awsProviderRegion string
68+
if awsProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]); awsProviderRegion == "" {
69+
awsProviderRegion = os.Getenv(awsRegionEnv)
70+
}
71+
if awsProviderRegion == "" {
72+
return nil, fmt.Errorf("aws provider region should not be empty")
6873
}
6974

7075
// Build random_password for aws_db_instance.

pkg/modules/generators/accessories/mysql/mysql_generator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (g *mysqlGenerator) Generate(spec *apiv1.Intent) error {
119119
return err
120120
}
121121

122-
switch providerType {
122+
switch strings.ToLower(providerType) {
123123
case "aws":
124124
secret, err = g.generateAWSResources(db, spec)
125125
case "alicloud":

pkg/modules/generators/accessories/postgres/alicloud_rds.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package postgres
22

33
import (
4+
"fmt"
5+
"os"
46
"strings"
57

68
v1 "k8s.io/api/core/v1"
@@ -12,6 +14,7 @@ import (
1214

1315
const (
1416
defaultAlicloudProviderURL = "registry.terraform.io/aliyun/alicloud/1.209.1"
17+
alicloudRegionEnv = "ALICLOUD_REGION"
1518
alicloudDBInstance = "alicloud_db_instance"
1619
alicloudDBConnection = "alicloud_db_connection"
1720
alicloudRDSAccount = "alicloud_rds_account"
@@ -56,9 +59,12 @@ func (g *postgresGenerator) generateAlicloudResources(db *postgres.PostgreSQL, s
5659
}
5760

5861
// Get the alicloud provider region, and the region of the alicloud provider must be set.
59-
alicloudProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider])
60-
if err != nil {
61-
return nil, err
62+
var alicloudProviderRegion string
63+
if alicloudProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AlicloudProvider]); alicloudProviderRegion == "" {
64+
alicloudProviderRegion = os.Getenv(alicloudRegionEnv)
65+
}
66+
if alicloudProviderRegion == "" {
67+
return nil, fmt.Errorf("alicloud provider region should not be empty")
6268
}
6369

6470
// Build alicloud_db_instance.

pkg/modules/generators/accessories/postgres/aws_rds.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package postgres
22

33
import (
44
"fmt"
5+
"os"
56

67
v1 "k8s.io/api/core/v1"
78
apiv1 "kusionstack.io/kusion/pkg/apis/core/v1"
@@ -12,6 +13,7 @@ import (
1213

1314
const (
1415
defaultAWSProviderURL = "registry.terraform.io/hashicorp/aws/5.0.1"
16+
awsRegionEnv = "AWS_REGION"
1517
awsSecurityGroup = "aws_security_group"
1618
awsDBInstance = "aws_db_instance"
1719
)
@@ -60,9 +62,12 @@ func (g *postgresGenerator) generateAWSResources(db *postgres.PostgreSQL, spec *
6062
}
6163

6264
// Get the aws provider region, and the region of the aws provider must be set.
63-
awsProviderRegion, err := inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider])
64-
if err != nil {
65-
return nil, err
65+
var awsProviderRegion string
66+
if awsProviderRegion = inputs.GetProviderRegion(g.tfConfigs[inputs.AWSProvider]); awsProviderRegion == "" {
67+
awsProviderRegion = os.Getenv(awsRegionEnv)
68+
}
69+
if awsProviderRegion == "" {
70+
return nil, fmt.Errorf("aws provider region should not be empty")
6671
}
6772

6873
// Build random_password for aws_db_instance.

pkg/modules/generators/accessories/postgres/postgres_generator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (g *postgresGenerator) Generate(spec *apiv1.Intent) error {
119119
return err
120120
}
121121

122-
switch providerType {
122+
switch strings.ToLower(providerType) {
123123
case "aws":
124124
secret, err = g.generateAWSResources(db, spec)
125125
case "alicloud":

pkg/modules/inputs/provider.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
const (
1212
errInvalidProviderSource = "invalid provider source: %s"
1313
errEmptyProviderVersion = "empty provider version"
14-
errEmptyProviderRegion = "empty provider region for source: %s"
1514
)
1615

1716
const (
@@ -75,11 +74,11 @@ func GetProviderURL(providerConfig *apiv1.ProviderConfig) (string, error) {
7574
}
7675

7776
// GetProviderRegion returns the region of the terraform provider.
78-
func GetProviderRegion(providerConfig *apiv1.ProviderConfig) (string, error) {
77+
func GetProviderRegion(providerConfig *apiv1.ProviderConfig) string {
7978
region, ok := providerConfig.GenericConfig["region"]
8079
if !ok {
81-
return "", fmt.Errorf(errEmptyProviderRegion, providerConfig.Source)
80+
return ""
8281
}
8382

84-
return region.(string), nil
83+
return region.(string)
8584
}

pkg/modules/inputs/provider_test.go

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,9 @@ func TestGetProviderURL(t *testing.T) {
109109

110110
func TestGetProviderRegion(t *testing.T) {
111111
tests := []struct {
112-
name string
113-
data *apiv1.ProviderConfig
114-
expected string
115-
expectedErr error
112+
name string
113+
data *apiv1.ProviderConfig
114+
expected string
116115
}{
117116
{
118117
name: "Valid Provider Config",
@@ -123,27 +122,20 @@ func TestGetProviderRegion(t *testing.T) {
123122
"region": "us-east-1",
124123
},
125124
},
126-
expected: "us-east-1",
127-
expectedErr: nil,
125+
expected: "us-east-1",
128126
},
129127
{
130128
name: "Empty Provider Region",
131129
data: &apiv1.ProviderConfig{
132130
Source: "hashicorp/aws",
133131
Version: "5.0.1",
134132
},
135-
expected: "",
136-
expectedErr: fmt.Errorf(errEmptyProviderRegion, "hashicorp/aws"),
133+
expected: "",
137134
},
138135
}
139136

140137
for _, test := range tests {
141-
actual, actualErr := GetProviderRegion(test.data)
142-
if test.expectedErr == nil {
143-
assert.Equal(t, test.expected, actual)
144-
assert.NoError(t, actualErr)
145-
} else {
146-
assert.ErrorContains(t, actualErr, test.expectedErr.Error())
147-
}
138+
actual := GetProviderRegion(test.data)
139+
assert.Equal(t, test.expected, actual)
148140
}
149141
}

0 commit comments

Comments
 (0)