Skip to content

Commit

Permalink
Bring the Vault provider to gcp sql parity with Vault (#2012)
Browse files Browse the repository at this point in the history
Add support for "auth_type" and "service_account_json" fields for MySQL and Postgres db secret types
  • Loading branch information
kpcraig authored Oct 2, 2023
1 parent 4a5cd4a commit 7f71cc3
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 16 deletions.
101 changes: 89 additions & 12 deletions vault/resource_database_secret_backend_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type connectionStringConfig struct {
excludeUsernameTemplate bool
includeUserPass bool
includeDisableEscaping bool
isCloud bool
}

const (
Expand Down Expand Up @@ -569,6 +570,7 @@ func getDatabaseSchema(typ schema.ValueType) schemaMap {
Elem: connectionStringResource(&connectionStringConfig{
includeUserPass: true,
includeDisableEscaping: true,
isCloud: true,
}),
MaxItems: 1,
ConflictsWith: util.CalculateConflictsWith(dbEnginePostgres.Name(), dbEngineTypes),
Expand Down Expand Up @@ -765,6 +767,20 @@ func connectionStringResource(config *connectionStringConfig) *schema.Resource {
}
}

if config.isCloud {
res.Schema["auth_type"] = &schema.Schema{
Type: schema.TypeString,
Optional: true,
Description: "Specify alternative authorization type. (Only 'gcp_iam' is valid currently)",
}
res.Schema["service_account_json"] = &schema.Schema{
Type: schema.TypeString,
Optional: true,
Description: "A JSON encoded credential for use with IAM authorization",
Sensitive: true,
}
}

if !config.excludeUsernameTemplate {
res.Schema["username_template"] = &schema.Schema{
Type: schema.TypeString,
Expand All @@ -787,6 +803,7 @@ func connectionStringResource(config *connectionStringConfig) *schema.Resource {
func mysqlConnectionStringResource() *schema.Resource {
r := connectionStringResource(&connectionStringConfig{
includeUserPass: true,
isCloud: true,
})
r.Schema["tls_certificate_key"] = &schema.Schema{
Type: schema.TypeString,
Expand Down Expand Up @@ -866,7 +883,7 @@ func getDBEngineFromResp(engines []*dbEngine, r *api.Secret) (*dbEngine, error)
return nil, fmt.Errorf("no supported database engines found for plugin %q", pluginName)
}

func getDatabaseAPIDataForEngine(engine *dbEngine, idx int, d *schema.ResourceData) (map[string]interface{}, error) {
func getDatabaseAPIDataForEngine(engine *dbEngine, idx int, d *schema.ResourceData, meta interface{}) (map[string]interface{}, error) {
prefix := engine.ResourcePrefix(idx)
data := map[string]interface{}{}

Expand All @@ -893,7 +910,7 @@ func getDatabaseAPIDataForEngine(engine *dbEngine, idx int, d *schema.ResourceDa
case dbEngineMSSQL:
setMSSQLDatabaseConnectionData(d, prefix, data)
case dbEngineMySQL:
setMySQLDatabaseConnectionData(d, prefix, data)
setMySQLDatabaseConnectionData(d, prefix, data, meta)
case dbEngineMySQLRDS:
setDatabaseConnectionDataWithUserPass(d, prefix, data)
case dbEngineMySQLAurora:
Expand All @@ -903,7 +920,7 @@ func getDatabaseAPIDataForEngine(engine *dbEngine, idx int, d *schema.ResourceDa
case dbEngineOracle:
setDatabaseConnectionDataWithUserPass(d, prefix, data)
case dbEnginePostgres:
setDatabaseConnectionDataWithDisableEscaping(d, prefix, data)
setPostgresDatabaseConnectionData(d, prefix, data, meta)
case dbEngineElasticSearch:
setElasticsearchDatabaseConnectionData(d, prefix, data)
case dbEngineRedis:
Expand Down Expand Up @@ -1028,6 +1045,7 @@ func getConnectionDetailsFromResponse(d *schema.ResourceData, prefix string, res
result["username_template"] = v.(string)
}
}

return result
}

Expand All @@ -1049,6 +1067,31 @@ func getMSSQLConnectionDetailsFromResponse(d *schema.ResourceData, prefix string
return result, nil
}

func getPostgresConnectionDetailsFromResponse(d *schema.ResourceData, prefix string, resp *api.Secret, meta interface{}) map[string]interface{} {
result := getConnectionDetailsFromResponseWithDisableEscaping(d, prefix, resp)
details := resp.Data["connection_details"]
data, ok := details.(map[string]interface{})
if !ok {
return nil
}

// cloud specific
if provider.IsAPISupported(meta, provider.VaultVersion115) {
if v, ok := data["auth_type"]; ok {
result["auth_type"] = v.(string)
}
if v, ok := d.GetOk(prefix + "service_account_json"); ok {
result["service_account_json"] = v.(string)
} else {
if v, ok := data["service_account_json"]; ok {
result["service_account_json"] = v.(string)
}
}
}

return result
}

func getConnectionDetailsFromResponseWithDisableEscaping(d *schema.ResourceData, prefix string, resp *api.Secret) map[string]interface{} {
result := getConnectionDetailsFromResponseWithUserPass(d, prefix, resp)
if result == nil {
Expand All @@ -1063,7 +1106,7 @@ func getConnectionDetailsFromResponseWithDisableEscaping(d *schema.ResourceData,
return result
}

func getMySQLConnectionDetailsFromResponse(d *schema.ResourceData, prefix string, resp *api.Secret) map[string]interface{} {
func getMySQLConnectionDetailsFromResponse(d *schema.ResourceData, prefix string, resp *api.Secret, meta interface{}) map[string]interface{} {
result := getConnectionDetailsFromResponseWithUserPass(d, prefix, resp)
details := resp.Data["connection_details"]
data, ok := details.(map[string]interface{})
Expand All @@ -1084,6 +1127,21 @@ func getMySQLConnectionDetailsFromResponse(d *schema.ResourceData, prefix string
result["tls_ca"] = v.(string)
}
}

if provider.IsAPISupported(meta, provider.VaultVersion115) {
// cloud specific
if v, ok := data["auth_type"]; ok {
result["auth_type"] = v.(string)
}
if v, ok := d.GetOk(prefix + "service_account_json"); ok {
result["service_account_json"] = v.(string)
} else {
if v, ok := data["service_account_json"]; ok {
result["service_account_json"] = v.(string)
}
}
}

return result
}

Expand Down Expand Up @@ -1367,6 +1425,18 @@ func setDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[s
}
}

func setCloudDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[string]interface{}, meta interface{}) {
if !provider.IsAPISupported(meta, provider.VaultVersion115) {
return
}
if v, ok := d.GetOk(prefix + "auth_type"); ok {
data["auth_type"] = v.(string)
}
if v, ok := d.GetOk(prefix + "service_account_json"); ok {
data["service_account_json"] = v.(string)
}
}

func setMSSQLDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[string]interface{}) {
setDatabaseConnectionDataWithDisableEscaping(d, prefix, data)
if v, ok := d.GetOk(prefix + "contained_db"); ok {
Expand All @@ -1378,8 +1448,9 @@ func setMSSQLDatabaseConnectionData(d *schema.ResourceData, prefix string, data
}
}

func setMySQLDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[string]interface{}) {
func setMySQLDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[string]interface{}, meta interface{}) {
setDatabaseConnectionDataWithUserPass(d, prefix, data)
setCloudDatabaseConnectionData(d, prefix, data, meta)
if v, ok := d.GetOk(prefix + "tls_certificate_key"); ok {
data["tls_certificate_key"] = v.(string)
}
Expand All @@ -1388,6 +1459,12 @@ func setMySQLDatabaseConnectionData(d *schema.ResourceData, prefix string, data
}
}

func setPostgresDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[string]interface{}, meta interface{}) {
setDatabaseConnectionDataWithDisableEscaping(d, prefix, data)
setCloudDatabaseConnectionData(d, prefix, data, meta)

}

func setRedisDatabaseConnectionData(d *schema.ResourceData, prefix string, data map[string]interface{}) {
if v, ok := d.GetOk(prefix + "host"); ok {
data["host"] = v.(string)
Expand Down Expand Up @@ -1593,7 +1670,7 @@ func databaseSecretBackendConnectionCreateOrUpdate(
path := databaseSecretBackendConnectionPath(
d.Get("backend").(string), d.Get("name").(string))
if err := writeDatabaseSecretConfig(
d, client, engine, 0, false, path); err != nil {
d, client, engine, 0, false, path, meta); err != nil {
return err
}

Expand All @@ -1604,9 +1681,9 @@ func databaseSecretBackendConnectionCreateOrUpdate(
}

func writeDatabaseSecretConfig(d *schema.ResourceData, client *api.Client,
engine *dbEngine, idx int, unifiedSchema bool, path string,
engine *dbEngine, idx int, unifiedSchema bool, path string, meta interface{},
) error {
data, err := getDatabaseAPIDataForEngine(engine, idx, d)
data, err := getDatabaseAPIDataForEngine(engine, idx, d, meta)
if err != nil {
return err
}
Expand Down Expand Up @@ -1729,7 +1806,7 @@ func databaseSecretBackendConnectionRead(d *schema.ResourceData, meta interface{
return err
}

result, err := getDBConnectionConfig(d, engine, 0, resp)
result, err := getDBConnectionConfig(d, engine, 0, resp, meta)
if err != nil {
return err
}
Expand Down Expand Up @@ -1785,7 +1862,7 @@ func getDBCommonConfig(d *schema.ResourceData, resp *api.Secret,
}

func getDBConnectionConfig(d *schema.ResourceData, engine *dbEngine, idx int,
resp *api.Secret,
resp *api.Secret, meta interface{},
) (map[string]interface{}, error) {
var result map[string]interface{}

Expand Down Expand Up @@ -1814,7 +1891,7 @@ func getDBConnectionConfig(d *schema.ResourceData, engine *dbEngine, idx int,
}
result = values
case dbEngineMySQL:
result = getMySQLConnectionDetailsFromResponse(d, prefix, resp)
result = getMySQLConnectionDetailsFromResponse(d, prefix, resp, meta)
case dbEngineMySQLRDS:
result = getConnectionDetailsFromResponseWithUserPass(d, prefix, resp)
case dbEngineMySQLAurora:
Expand All @@ -1824,7 +1901,7 @@ func getDBConnectionConfig(d *schema.ResourceData, engine *dbEngine, idx int,
case dbEngineOracle:
result = getConnectionDetailsFromResponseWithUserPass(d, prefix, resp)
case dbEnginePostgres:
result = getConnectionDetailsFromResponseWithDisableEscaping(d, prefix, resp)
result = getPostgresConnectionDetailsFromResponse(d, prefix, resp, meta)
case dbEngineElasticSearch:
result = getElasticsearchConnectionDetailsFromResponse(d, prefix, resp)
case dbEngineSnowflake:
Expand Down
124 changes: 124 additions & 0 deletions vault/resource_database_secret_backend_connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,47 @@ func TestAccDatabaseSecretBackendConnection_mssql(t *testing.T) {
})
}

func TestAccDatabaseSecretBackendConnection_mysql_cloud(t *testing.T) {
// wanted this to be the included with the following test, but the env-var check is different
values := testutil.SkipTestEnvUnset(t, "MYSQL_CLOUD_CONNECTION_URL", "MYSQL_CLOUD_CONNECTION_SERVICE_ACCOUNT_JSON")
connURL, saJSON := values[0], values[1]

backend := acctest.RandomWithPrefix("tf-test-db")
name := acctest.RandomWithPrefix("db")
resource.Test(t, resource.TestCase{
ProviderFactories: providerFactories,
PreCheck: func() {
testutil.TestAccPreCheck(t)
SkipIfAPIVersionLT(t, testProvider.Meta(), provider.VaultVersion115)
},
CheckDestroy: testAccDatabaseSecretBackendConnectionCheckDestroy,
Steps: []resource.TestStep{
{
Config: testAccDatabaseSecretBackendConnectionConfig_mysql_cloud(name, backend, connURL, "gcp_iam", saJSON),
Check: testComposeCheckFuncCommonDatabaseSecretBackend(name, backend, dbEngineMySQL.DefaultPluginName(),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "allowed_roles.#", "2"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "allowed_roles.0", "dev"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "allowed_roles.1", "prod"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "root_rotation_statements.#", "1"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "root_rotation_statements.0", "FOOBAR"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "verify_connection", "true"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "mysql.0.connection_url", connURL),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "mysql.0.auth_type", "gcp_iam"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "mysql.0.max_open_connections", "2"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "mysql.0.max_idle_connections", "0"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "mysql.0.max_connection_lifetime", "0"),
),
},
{
ResourceName: testDefaultDatabaseSecretBackendResource,
ImportState: true,
ImportStateVerify: true,
ImportStateVerifyIgnore: []string{"verify_connection", "mysql.0.service_account_json"},
},
},
})
}

func TestAccDatabaseSecretBackendConnection_mysql(t *testing.T) {
MaybeSkipDBTests(t, dbEngineMySQL)

Expand Down Expand Up @@ -776,6 +817,45 @@ func TestAccDatabaseSecretBackendConnection_postgresql(t *testing.T) {
})
}

func TestAccDatabaseSecretBackendConnection_postgresql_cloud(t *testing.T) {
// wanted this to be the included with the following test, but the env-var check is different
values := testutil.SkipTestEnvUnset(t, "POSTGRES_CLOUD_URL", "POSTGRES_CLOUD_SERVICE_ACCOUNT_JSON")
connURL, saJSON := values[0], values[1]

backend := acctest.RandomWithPrefix("tf-test-db")
name := acctest.RandomWithPrefix("db")
resource.Test(t, resource.TestCase{
ProviderFactories: providerFactories,
PreCheck: func() {
testutil.TestAccPreCheck(t)
SkipIfAPIVersionLT(t, testProvider.Meta(), provider.VaultVersion115)
},
CheckDestroy: testAccDatabaseSecretBackendConnectionCheckDestroy,
Steps: []resource.TestStep{
{
Config: testAccDatabaseSecretBackendConnectionConfig_postgres_cloud(name, backend, connURL, "gcp_iam", saJSON),
Check: testComposeCheckFuncCommonDatabaseSecretBackend(name, backend, dbEngineMySQL.DefaultPluginName(),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "allowed_roles.#", "2"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "allowed_roles.0", "dev"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "allowed_roles.1", "prod"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "root_rotation_statements.#", "1"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "root_rotation_statements.0", "FOOBAR"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "verify_connection", "true"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "postgresql.0.connection_url", connURL),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "postgresql.0.disable_escaping", "true"),
resource.TestCheckResourceAttr(testDefaultDatabaseSecretBackendResource, "postgresql.0.auth_type", "gcp_iam"),
),
},
{
ResourceName: testDefaultDatabaseSecretBackendResource,
ImportState: true,
ImportStateVerify: true,
ImportStateVerifyIgnore: []string{"verify_connection", "postgres.0.service_account_json"},
},
},
})
}

func TestAccDatabaseSecretBackendConnection_elasticsearch(t *testing.T) {
MaybeSkipDBTests(t, dbEngineElasticSearch)

Expand Down Expand Up @@ -1519,6 +1599,28 @@ resource "vault_database_secret_backend_connection" "test" {
`, path, name, connURL, username, password)
}

func testAccDatabaseSecretBackendConnectionConfig_mysql_cloud(name, path, connURL, authType, serviceAccountJSON string) string {
return fmt.Sprintf(`
resource "vault_mount" "db" {
path = "%s"
type = "database"
}
resource "vault_database_secret_backend_connection" "test" {
backend = vault_mount.db.path
name = "%s"
allowed_roles = ["dev", "prod"]
root_rotation_statements = ["FOOBAR"]
mysql {
connection_url = "%s"
auth_type = "%s"
service_account_json = "%s"
}
}
`, path, name, connURL, authType, serviceAccountJSON)
}

func testAccDatabaseSecretBackendConnectionConfig_postgresql(name, path, userTempl, username, password, openConn, idleConn, maxConnLifetime string, parsedURL *url.URL) string {
return fmt.Sprintf(`
resource "vault_mount" "db" {
Expand Down Expand Up @@ -1566,6 +1668,28 @@ resource "vault_database_secret_backend_connection" "test" {
`, path, name, parsedURL.String())
}

func testAccDatabaseSecretBackendConnectionConfig_postgres_cloud(name, path, connURL, authType, serviceAccountJSON string) string {
return fmt.Sprintf(`
resource "vault_mount" "db" {
path = "%s"
type = "database"
}
resource "vault_database_secret_backend_connection" "test" {
backend = vault_mount.db.path
name = "%s"
allowed_roles = ["dev", "prod"]
root_rotation_statements = ["FOOBAR"]
postgresql {
connection_url = "%s"
auth_type = "%s"
service_account_json = "%s"
}
}
`, path, name, connURL, authType, serviceAccountJSON)
}

func testAccDatabaseSecretBackendConnectionConfig_snowflake(name, path, url, username, password, userTempl string) string {
return fmt.Sprintf(`
resource "vault_mount" "db" {
Expand Down
Loading

0 comments on commit 7f71cc3

Please sign in to comment.