Skip to content

Commit

Permalink
Fix Migration Script for New Deployment (#908)
Browse files Browse the repository at this point in the history
### Feature or Bugfix
<!-- please choose -->
- Bugfix

### Detail
- When deploying from scratch - the migration script for byo vpc for
mlstudio had a line of code `if has_table('sagemaker_studio_domain')`
which was evaluating to `False` and skipping necessary migration code
- Removed the `if has_table()` and just throw Exception in try/catch
block for upgrade migration script

- Also as part of this PR I made sure the security group gets created in
the mlstudio_extension stack no matter what type of VPC we use for ML
Studio Domain
- Without the Security Group VpcOnly Deployments of ML Studio Domain
with `security_group=[]` will never be able to connect to internet
without manual changes
- Tested with Imported VPC and provisioned Studio User Profile - can
access internet via NAT Gateway as prescribed with VPCOnly deployments
if internet access is required
- Tested changes to ML Studio Domain Security Group does NOT cause new
creation of Domain (only UpdateDomain no replacement) --> Environment
Stack Updates Successfully


### Relates
- #894

### Security
N/A
Please answer the questions below briefly where applicable, or write
`N/A`. Based on
[OWASP 10](https://owasp.org/Top10/en/).

- Does this PR introduce or modify any input fields or queries - this
includes
fetching data from storage outside the application (e.g. a database, an
S3 bucket)?
  - Is the input sanitized?
- What precautions are you taking before deserializing the data you
consume?
  - Is injection prevented by parametrizing queries?
  - Have you ensured no `eval` or similar functions are used?
- Does this PR introduce any functionality or component that requires
authorization?
- How have you ensured it respects the existing AuthN/AuthZ mechanisms?
  - Are you logging failed auth attempts?
- Are you using or adding any cryptographic features?
  - Do you use a standard proven implementations?
  - Are the used keys controlled by the customer? Where are they stored?
- Are you introducing any new policies/roles/users?
  - Have you used the least-privilege principle? How?


By submitting this pull request, I confirm that my contribution is made
under the terms of the Apache 2.0 license.
  • Loading branch information
noah-paige authored Dec 13, 2023
1 parent 4090107 commit ad70a88
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 82 deletions.
25 changes: 12 additions & 13 deletions backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def extent(setup: EnvironmentSetup):
logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain')
vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId)
subnet_ids = domain.subnetIds
security_groups = []
else:
cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
accountid=_environment.AwsAccountId, region=_environment.region
Expand All @@ -55,7 +54,6 @@ def extent(setup: EnvironmentSetup):
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets]
subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets]
security_groups = []
else:
logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...")
# Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways
Expand Down Expand Up @@ -95,19 +93,20 @@ def extent(setup: EnvironmentSetup):
resource_type=ec2.FlowLogResourceType.from_vpc(vpc),
destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role)
)
# setup security group to be used for sagemaker studio domain
sagemaker_sg = ec2.SecurityGroup(
setup,
"SecurityGroup",
vpc=vpc,
description="Security Group for SageMaker Studio",
security_group_name=domain.sagemakerStudioDomainName,
)

sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
security_groups = [sagemaker_sg.security_group_id]
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]

# setup security group to be used for sagemaker studio domain
sagemaker_sg = ec2.SecurityGroup(
setup,
"SecurityGroup",
vpc=vpc,
description="Security Group for SageMaker Studio",
security_group_name=domain.sagemakerStudioDomainName,
)

sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
security_groups = [sagemaker_sg.security_group_id]

vpc_id = vpc.vpc_id

sagemaker_domain_role = iam.Role(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,77 +61,73 @@ def upgrade():
1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information
"""
try:
envname = os.getenv('envname', 'local')
engine = get_engine(envname=envname).engine

bind = op.get_bind()
session = orm.Session(bind=bind)

if has_table('sagemaker_studio_domain', engine):
print("Updating sagemaker_studio_domain table...")
op.alter_column(
'sagemaker_studio_domain',
'sagemakerStudioDomainID',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'SagemakerStudioStatus',
nullable=True,
existing_type=sa.String()
print("Updating sagemaker_studio_domain table...")
op.alter_column(
'sagemaker_studio_domain',
'sagemakerStudioDomainID',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'SagemakerStudioStatus',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'RoleArn',
new_column_name='DefaultDomainRoleName',
nullable=False,
existing_type=sa.String()
)

op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))

op.create_foreign_key(
"fk_sagemaker_studio_domain_env_uri",
"sagemaker_studio_domain", "environment",
["environmentUri"], ["environmentUri"],
)

print("Update sagemaker_studio_domain table done.")
print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")

env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
and_(
EnvironmentParameter.key == "mlStudiosEnabled",
EnvironmentParameter.value == "true"
)
op.alter_column(
'sagemaker_studio_domain',
'RoleArn',
new_column_name='DefaultDomainRoleName',
nullable=False,
existing_type=sa.String()
)

op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))

op.create_foreign_key(
"fk_sagemaker_studio_domain_env_uri",
"sagemaker_studio_domain", "environment",
["environmentUri"], ["environmentUri"],
)

print("Update sagemaker_studio_domain table done.")
print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")

env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
and_(
EnvironmentParameter.key == "mlStudiosEnabled",
EnvironmentParameter.value == "true"
).all()
for param in env_mlstudio_parameters:
env: Environment = session.query(Environment).filter(
Environment.environmentUri == param.environmentUri
).first()

domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
SagemakerStudioDomain.environmentUri == env.environmentUri
).first()
if not domain:
domain = SagemakerStudioDomain(
label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
owner=env.owner,
description='No description provided',
environmentUri=env.environmentUri,
AWSAccountId=env.AwsAccountId,
region=env.region,
DefaultDomainRoleName="RoleSagemakerStudioUsers",
sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
vpcType="unknown",
SamlGroupName=env.SamlGroupName
)
).all()
for param in env_mlstudio_parameters:
env: Environment = session.query(Environment).filter(
Environment.environmentUri == param.environmentUri
).first()

domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
SagemakerStudioDomain.environmentUri == env.environmentUri
).first()
if not domain:
domain = SagemakerStudioDomain(
label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
owner=env.owner,
description='No description provided',
environmentUri=env.environmentUri,
AWSAccountId=env.AwsAccountId,
region=env.region,
DefaultDomainRoleName="RoleSagemakerStudioUsers",
sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
vpcType="unknown",
SamlGroupName=env.SamlGroupName
)
session.add(domain)
session.add(domain)
session.flush()
session.commit()
print("Fill of sagemaker_studio_domain table is done")
Expand Down
4 changes: 1 addition & 3 deletions deploy/cdk_exec_policy/cdkExecPolicy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ Resources:
Effect: Allow
Action:
- 'sagemaker:*Tag*'
- 'sagemaker:CreateDomain'
- 'sagemaker:DeleteDomain'
- 'sagemaker:DescribeDomain'
- 'sagemaker:*Domain'
- 'sagemaker:CreateApp'
- 'sagemaker:CreateUserProfile'
- 'sagemaker:DescribeUserProfile'
Expand Down

0 comments on commit ad70a88

Please sign in to comment.