Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Migration Script for New Deployment #908

Merged
merged 5 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def extent(setup: EnvironmentSetup):
logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain')
vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId)
subnet_ids = domain.subnetIds
security_groups = []
else:
cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
accountid=_environment.AwsAccountId, region=_environment.region
Expand All @@ -55,7 +54,6 @@ def extent(setup: EnvironmentSetup):
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]
subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets]
subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets]
security_groups = []
else:
logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...")
# Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways
Expand Down Expand Up @@ -95,19 +93,20 @@ def extent(setup: EnvironmentSetup):
resource_type=ec2.FlowLogResourceType.from_vpc(vpc),
destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role)
)
# setup security group to be used for sagemaker studio domain
sagemaker_sg = ec2.SecurityGroup(
setup,
"SecurityGroup",
vpc=vpc,
description="Security Group for SageMaker Studio",
security_group_name=domain.sagemakerStudioDomainName,
)

sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
security_groups = [sagemaker_sg.security_group_id]
subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets]

# setup security group to be used for sagemaker studio domain
sagemaker_sg = ec2.SecurityGroup(
setup,
"SecurityGroup",
vpc=vpc,
description="Security Group for SageMaker Studio",
security_group_name=domain.sagemakerStudioDomainName,
)

sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic())
security_groups = [sagemaker_sg.security_group_id]

vpc_id = vpc.vpc_id

sagemaker_domain_role = iam.Role(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,77 +61,73 @@ def upgrade():
1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information
"""
try:
envname = os.getenv('envname', 'local')
engine = get_engine(envname=envname).engine

bind = op.get_bind()
session = orm.Session(bind=bind)

if has_table('sagemaker_studio_domain', engine):
print("Updating sagemaker_studio_domain table...")
op.alter_column(
'sagemaker_studio_domain',
'sagemakerStudioDomainID',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'SagemakerStudioStatus',
nullable=True,
existing_type=sa.String()
print("Updating sagemaker_studio_domain table...")
op.alter_column(
'sagemaker_studio_domain',
'sagemakerStudioDomainID',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'SagemakerStudioStatus',
nullable=True,
existing_type=sa.String()
)
op.alter_column(
'sagemaker_studio_domain',
'RoleArn',
new_column_name='DefaultDomainRoleName',
nullable=False,
existing_type=sa.String()
)

op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))

op.create_foreign_key(
"fk_sagemaker_studio_domain_env_uri",
"sagemaker_studio_domain", "environment",
["environmentUri"], ["environmentUri"],
)

print("Update sagemaker_studio_domain table done.")
print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")

env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
and_(
EnvironmentParameter.key == "mlStudiosEnabled",
EnvironmentParameter.value == "true"
)
op.alter_column(
'sagemaker_studio_domain',
'RoleArn',
new_column_name='DefaultDomainRoleName',
nullable=False,
existing_type=sa.String()
)

op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False))
op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True))
op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True))
op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False))

op.create_foreign_key(
"fk_sagemaker_studio_domain_env_uri",
"sagemaker_studio_domain", "environment",
["environmentUri"], ["environmentUri"],
)

print("Update sagemaker_studio_domain table done.")
print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...")

env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter(
and_(
EnvironmentParameter.key == "mlStudiosEnabled",
EnvironmentParameter.value == "true"
).all()
for param in env_mlstudio_parameters:
env: Environment = session.query(Environment).filter(
Environment.environmentUri == param.environmentUri
).first()

domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
SagemakerStudioDomain.environmentUri == env.environmentUri
).first()
if not domain:
domain = SagemakerStudioDomain(
label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
owner=env.owner,
description='No description provided',
environmentUri=env.environmentUri,
AWSAccountId=env.AwsAccountId,
region=env.region,
DefaultDomainRoleName="RoleSagemakerStudioUsers",
sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
vpcType="unknown",
SamlGroupName=env.SamlGroupName
)
).all()
for param in env_mlstudio_parameters:
env: Environment = session.query(Environment).filter(
Environment.environmentUri == param.environmentUri
).first()

domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter(
SagemakerStudioDomain.environmentUri == env.environmentUri
).first()
if not domain:
domain = SagemakerStudioDomain(
label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
owner=env.owner,
description='No description provided',
environmentUri=env.environmentUri,
AWSAccountId=env.AwsAccountId,
region=env.region,
DefaultDomainRoleName="RoleSagemakerStudioUsers",
sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}",
vpcType="unknown",
SamlGroupName=env.SamlGroupName
)
session.add(domain)
session.add(domain)
session.flush()
session.commit()
print("Fill of sagemaker_studio_domain table is done")
Expand Down
4 changes: 1 addition & 3 deletions deploy/cdk_exec_policy/cdkExecPolicy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ Resources:
Effect: Allow
Action:
- 'sagemaker:*Tag*'
- 'sagemaker:CreateDomain'
- 'sagemaker:DeleteDomain'
- 'sagemaker:DescribeDomain'
- 'sagemaker:*Domain'
- 'sagemaker:CreateApp'
- 'sagemaker:CreateUserProfile'
- 'sagemaker:DescribeUserProfile'
Expand Down
Loading