diff --git a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py index 49082ccfb..2305a0e28 100644 --- a/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py +++ b/backend/dataall/modules/mlstudio/cdk/mlstudio_extension.py @@ -40,7 +40,6 @@ def extent(setup: EnvironmentSetup): logger.info(f'Using VPC {domain.vpcId} and subnets {domain.subnetIds} for SageMaker Studio domain') vpc = ec2.Vpc.from_lookup(setup, 'VPCStudio', vpc_id=domain.vpcId) subnet_ids = domain.subnetIds - security_groups = [] else: cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn( accountid=_environment.AwsAccountId, region=_environment.region @@ -55,7 +54,6 @@ def extent(setup: EnvironmentSetup): subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] subnet_ids += [public_subnet.subnet_id for public_subnet in vpc.public_subnets] subnet_ids += [isolated_subnet.subnet_id for isolated_subnet in vpc.isolated_subnets] - security_groups = [] else: logger.info("Default VPC not found, Exception. Creating a VPC for SageMaker resources...") # Create VPC with 3 Public Subnets and 3 Private subnets wit NAT Gateways @@ -95,19 +93,20 @@ def extent(setup: EnvironmentSetup): resource_type=ec2.FlowLogResourceType.from_vpc(vpc), destination=ec2.FlowLogDestination.to_cloud_watch_logs(log_group, vpc_flow_role) ) - # setup security group to be used for sagemaker studio domain - sagemaker_sg = ec2.SecurityGroup( - setup, - "SecurityGroup", - vpc=vpc, - description="Security Group for SageMaker Studio", - security_group_name=domain.sagemakerStudioDomainName, - ) - - sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) - security_groups = [sagemaker_sg.security_group_id] subnet_ids = [private_subnet.subnet_id for private_subnet in vpc.private_subnets] + # setup security group to be used for sagemaker studio domain + sagemaker_sg = ec2.SecurityGroup( + setup, + "SecurityGroup", + vpc=vpc, + description="Security Group for SageMaker Studio", + security_group_name=domain.sagemakerStudioDomainName, + ) + + sagemaker_sg.add_ingress_rule(sagemaker_sg, ec2.Port.all_traffic()) + security_groups = [sagemaker_sg.security_group_id] + vpc_id = vpc.vpc_id sagemaker_domain_role = iam.Role( diff --git a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py index a3ac794f3..1d376097b 100644 --- a/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py +++ b/backend/migrations/versions/71a5f5de322f_update_sagemaker_studio_domain.py @@ -61,77 +61,73 @@ def upgrade(): 1) update of the sagemaker_studio_domain table to include SageMaker Studio Domain VPC Information """ try: - envname = os.getenv('envname', 'local') - engine = get_engine(envname=envname).engine - bind = op.get_bind() session = orm.Session(bind=bind) - if has_table('sagemaker_studio_domain', engine): - print("Updating sagemaker_studio_domain table...") - op.alter_column( - 'sagemaker_studio_domain', - 'sagemakerStudioDomainID', - nullable=True, - existing_type=sa.String() - ) - op.alter_column( - 'sagemaker_studio_domain', - 'SagemakerStudioStatus', - nullable=True, - existing_type=sa.String() + print("Updating sagemaker_studio_domain table...") + op.alter_column( + 'sagemaker_studio_domain', + 'sagemakerStudioDomainID', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'SagemakerStudioStatus', + nullable=True, + existing_type=sa.String() + ) + op.alter_column( + 'sagemaker_studio_domain', + 'RoleArn', + new_column_name='DefaultDomainRoleName', + nullable=False, + existing_type=sa.String() + ) + + op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False)) + op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True)) + op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False)) + + op.create_foreign_key( + "fk_sagemaker_studio_domain_env_uri", + "sagemaker_studio_domain", "environment", + ["environmentUri"], ["environmentUri"], + ) + + print("Update sagemaker_studio_domain table done.") + print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") + + env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( + and_( + EnvironmentParameter.key == "mlStudiosEnabled", + EnvironmentParameter.value == "true" ) - op.alter_column( - 'sagemaker_studio_domain', - 'RoleArn', - new_column_name='DefaultDomainRoleName', - nullable=False, - existing_type=sa.String() - ) - - op.add_column("sagemaker_studio_domain", Column("sagemakerStudioDomainName", sa.String(), nullable=False)) - op.add_column("sagemaker_studio_domain", Column("vpcType", sa.String(), nullable=True)) - op.add_column("sagemaker_studio_domain", Column("vpcId", sa.String(), nullable=True)) - op.add_column("sagemaker_studio_domain", Column("subnetIds", postgresql.ARRAY(sa.String()), nullable=True)) - op.add_column("sagemaker_studio_domain", Column("SamlGroupName", sa.String(), nullable=False)) - - op.create_foreign_key( - "fk_sagemaker_studio_domain_env_uri", - "sagemaker_studio_domain", "environment", - ["environmentUri"], ["environmentUri"], - ) - - print("Update sagemaker_studio_domain table done.") - print("Filling sagemaker_studio_domain table with environments with mlstudio enabled...") - - env_mlstudio_parameters: [EnvironmentParameter] = session.query(EnvironmentParameter).filter( - and_( - EnvironmentParameter.key == "mlStudiosEnabled", - EnvironmentParameter.value == "true" + ).all() + for param in env_mlstudio_parameters: + env: Environment = session.query(Environment).filter( + Environment.environmentUri == param.environmentUri + ).first() + + domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( + SagemakerStudioDomain.environmentUri == env.environmentUri + ).first() + if not domain: + domain = SagemakerStudioDomain( + label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + owner=env.owner, + description='No description provided', + environmentUri=env.environmentUri, + AWSAccountId=env.AwsAccountId, + region=env.region, + DefaultDomainRoleName="RoleSagemakerStudioUsers", + sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", + vpcType="unknown", + SamlGroupName=env.SamlGroupName ) - ).all() - for param in env_mlstudio_parameters: - env: Environment = session.query(Environment).filter( - Environment.environmentUri == param.environmentUri - ).first() - - domain: SagemakerStudioDomain = session.query(SagemakerStudioDomain).filter( - SagemakerStudioDomain.environmentUri == env.environmentUri - ).first() - if not domain: - domain = SagemakerStudioDomain( - label=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - owner=env.owner, - description='No description provided', - environmentUri=env.environmentUri, - AWSAccountId=env.AwsAccountId, - region=env.region, - DefaultDomainRoleName="RoleSagemakerStudioUsers", - sagemakerStudioDomainName=f"SagemakerStudioDomain-{env.region}-{env.AwsAccountId}", - vpcType="unknown", - SamlGroupName=env.SamlGroupName - ) - session.add(domain) + session.add(domain) session.flush() session.commit() print("Fill of sagemaker_studio_domain table is done") diff --git a/deploy/cdk_exec_policy/cdkExecPolicy.yaml b/deploy/cdk_exec_policy/cdkExecPolicy.yaml index 08c79b9c8..21f113f2b 100644 --- a/deploy/cdk_exec_policy/cdkExecPolicy.yaml +++ b/deploy/cdk_exec_policy/cdkExecPolicy.yaml @@ -156,9 +156,7 @@ Resources: Effect: Allow Action: - 'sagemaker:*Tag*' - - 'sagemaker:CreateDomain' - - 'sagemaker:DeleteDomain' - - 'sagemaker:DescribeDomain' + - 'sagemaker:*Domain' - 'sagemaker:CreateApp' - 'sagemaker:CreateUserProfile' - 'sagemaker:DescribeUserProfile'