Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration Tests fixes #1744

Merged
merged 6 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
gql.Field(name='region', type=gql.String),
gql.Field(name='S3BucketName', type=gql.String),
gql.Field(name='GlueDatabaseName', type=gql.String),
gql.Field(name='GlueCrawlerName', type=gql.String),
gql.Field(name='IAMDatasetAdminRoleArn', type=gql.String),
gql.Field(name='KmsAlias', type=gql.String),
gql.Field(name='importedS3Bucket', type=gql.Boolean),
Expand Down
15 changes: 11 additions & 4 deletions tests_new/integration_tests/modules/s3_datasets/global_conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
sync_tables,
create_folder,
create_table_data_filter,
list_dataset_tables,
)

from tests_new.integration_tests.modules.datasets_base.queries import list_datasets
from integration_tests.aws_clients.s3 import S3Client as S3CommonClient
from integration_tests.modules.s3_datasets.aws_clients import S3Client, KMSClient, GlueClient, LakeFormationClient
Expand Down Expand Up @@ -179,8 +181,8 @@ def create_tables(client, dataset):
aws_session_token=creds['sessionToken'],
)
file_path = os.path.join(os.path.dirname(__file__), 'sample_data/csv_table/csv_sample.csv')
s3_client = S3Client(dataset_session, dataset.region)
glue_client = GlueClient(dataset_session, dataset.region)
s3_client = S3Client(dataset_session, dataset.restricted.region)
glue_client = GlueClient(dataset_session, dataset.restricted.region)
s3_client.upload_file_to_prefix(
local_file_path=file_path, s3_path=f'{dataset.restricted.S3BucketName}/integrationtest1'
)
Expand All @@ -198,8 +200,13 @@ def create_tables(client, dataset):
table_name='integrationtest2',
bucket=dataset.restricted.S3BucketName,
)
response = sync_tables(client, datasetUri=dataset.datasetUri)
return [table for table in response.get('nodes', []) if table.GlueTableName.startswith('integrationtest')]
sync_tables(client, datasetUri=dataset.datasetUri)
response = list_dataset_tables(client, datasetUri=dataset.datasetUri)
return [
table
for table in response.tables.get('nodes', [])
if table.restricted.GlueTableName.startswith('integrationtest')
]


def create_folders(client, dataset):
Expand Down
4 changes: 4 additions & 0 deletions tests_new/integration_tests/modules/s3_datasets/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
KmsAlias
S3BucketName
GlueDatabaseName
GlueCrawlerName
IAMDatasetAdminRoleArn
}
environment {
Expand Down Expand Up @@ -352,6 +353,7 @@ def update_folder(client, locationUri, input):
mutation updateDatasetStorageLocation($locationUri: String!, $input: ModifyDatasetStorageLocationInput!) {{
updateDatasetStorageLocation(locationUri: $locationUri, input: $input) {{
locationUri
label
}}
}}
""",
Expand Down Expand Up @@ -500,6 +502,8 @@ def list_dataset_tables(client, datasetUri):
tables {{
count
nodes {{
tableUri
label
restricted {{
GlueTableName
}}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_start_crawler(client1, dataset_fixture_name, request):
dataset = request.getfixturevalue(dataset_fixture_name)
dataset_uri = dataset.datasetUri
response = start_glue_crawler(client1, datasetUri=dataset_uri, input={})
assert_that(response.Name).is_equal_to(dataset.GlueCrawlerName)
assert_that(response.Name).is_equal_to(dataset.restricted.GlueCrawlerName)
# TODO: check it can run successfully + check sending prefix - We should first implement it in API


Expand Down
12 changes: 0 additions & 12 deletions tests_new/integration_tests/modules/s3_datasets/test_s3_folders.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,6 @@ def test_get_folder(client1, folders_fixture_name, request):
assert_that(response.label).is_equal_to('labelSessionFolderA')


@pytest.mark.parametrize(
'folders_fixture_name',
['session_s3_dataset1_folders'],
)
def test_get_folder_unauthorized(client2, folders_fixture_name, request):
folders = request.getfixturevalue(folders_fixture_name)
folder = folders[0]
assert_that(get_folder).raises(GqlError).when_called_with(client2, locationUri=folder.locationUri).contains(
'UnauthorizedOperation', 'GET_DATASET_FOLDER', folder.locationUri
)


@pytest.mark.parametrize(*FOLDERS_FIXTURES_PARAMS)
def test_update_folder(client1, folders_fixture_name, request):
folders = request.getfixturevalue(folders_fixture_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def test_list_dataset_tables(client1, dataset_fixture_name, request):
dataset = request.getfixturevalue(dataset_fixture_name)
response = list_dataset_tables(client1, dataset.datasetUri)
assert_that(response.tables.count).is_greater_than_or_equal_to(2)
tables = [table for table in response.tables.get('nodes', []) if table.GlueTableName.startswith('integrationtest')]
tables = [
table
for table in response.tables.get('nodes', [])
if table.restricted.GlueTableName.startswith('integrationtest')
]
assert_that(len(tables)).is_equal_to(2)


Expand Down Expand Up @@ -116,11 +120,12 @@ def test_delete_table(client1, dataset_fixture_name, request):
aws_secret_access_key=creds['SessionKey'],
aws_session_token=creds['sessionToken'],
)
GlueClient(dataset_session, dataset.region).create_table(
GlueClient(dataset_session, dataset.restricted.region).create_table(
database_name=dataset.restricted.GlueDatabaseName, table_name='todelete', bucket=dataset.restricted.S3BucketName
)
response = sync_tables(client1, datasetUri=dataset.datasetUri)
table_uri = [table.tableUri for table in response.get('nodes', []) if table.label == 'todelete'][0]
sync_tables(client1, datasetUri=dataset.datasetUri)
response = list_dataset_tables(client1, datasetUri=dataset.datasetUri)
table_uri = [table.tableUri for table in response.tables.get('nodes', []) if table.label == 'todelete'][0]
response = delete_table(client1, table_uri)
assert_that(response).is_true()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def test_start_table_profiling(client1, dataset_fixture_name, tables_fixture_nam
table = tables[0]
dataset_uri = dataset.datasetUri
response = start_dataset_profiling_run(
client1, input={'datasetUri': dataset_uri, 'tableUri': table.tableUri, 'GlueTableName': table.GlueTableName}
client1,
input={'datasetUri': dataset_uri, 'tableUri': table.tableUri, 'GlueTableName': table.restricted.GlueTableName},
)
assert_that(response.datasetUri).is_equal_to(dataset_uri)
assert_that(response.GlueTableName).is_equal_to(table.GlueTableName)
assert_that(response.GlueTableName).is_equal_to(table.restricted.GlueTableName)


@pytest.mark.parametrize('dataset_fixture_name', ['session_s3_dataset1'])
Expand Down Expand Up @@ -90,7 +91,7 @@ def test_get_table_profiling_run_by_confidentiality(client2, tables_fixture_name
table_uri = tables[0].tableUri
if confidentiality in ['Unclassified']:
response = get_table_profiling_run(client2, tableUri=table_uri)
assert_that(response.GlueTableName).is_equal_to(tables[0].GlueTableName)
assert_that(response.GlueTableName).is_equal_to(tables[0].restricted.GlueTableName)
else:
assert_that(get_table_profiling_run).raises(GqlError).when_called_with(client2, table_uri).contains(
'UnauthorizedOperation', 'GET_TABLE_PROFILING_METRICS'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ def principal1(request, group5, session_consumption_role_1):


@pytest.fixture(params=['Group', 'ConsumptionRole'])
def share_params_main(request, session_share_1, session_share_consrole_1, session_s3_dataset1):
def share_params_main(request, session_share_1, session_cross_acc_env_1, session_share_consrole_1, session_s3_dataset1):
if request.param == 'Group':
yield session_share_1, session_s3_dataset1
yield session_share_1, session_s3_dataset1, session_cross_acc_env_1
else:
yield session_share_consrole_1, session_s3_dataset1
yield session_share_consrole_1, session_s3_dataset1, session_cross_acc_env_1


@pytest.fixture(params=[(False, 'Group'), (True, 'Group'), (False, 'ConsumptionRole'), (True, 'ConsumptionRole')])
Expand Down Expand Up @@ -315,8 +315,10 @@ def persistent_role_share_1(


@pytest.fixture(params=['Group', 'ConsumptionRole'])
def persistent_share_params_main(request, persistent_role_share_1, persistent_group_share_1):
def persistent_share_params_main(
request, persistent_cross_acc_env_1, persistent_role_share_1, persistent_group_share_1
):
if request.param == 'Group':
yield persistent_group_share_1
yield persistent_group_share_1, persistent_cross_acc_env_1
else:
yield persistent_role_share_1
yield persistent_role_share_1, persistent_cross_acc_env_1
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,28 @@ def check_bucket_access(client, s3_client, bucket_name, should_have_access):


def check_accesspoint_access(client, s3_client, access_point_arn, item_uri, should_have_access):
folder = get_folder(client, item_uri)
if should_have_access:
folder = get_folder(client, item_uri)
assert_that(s3_client.list_accesspoint_folder_objects(access_point_arn, folder.S3Prefix + '/')).is_not_none()
else:
assert_that(get_folder).raises(Exception).when_called_with(client, item_uri).contains(
'is not authorized to perform: GET_DATASET_FOLDER'
)
assert_that(s3_client.list_accesspoint_folder_objects).raises(ClientError).when_called_with(
access_point_arn, folder.S3Prefix + '/'
).contains('AccessDenied')


def check_share_items_access(
client,
group,
shareUri,
share_environment,
consumption_role,
env_client,
):
share = get_share_object(client, shareUri)
dataset = share.dataset
principal_type = share.principal.principalType
if principal_type == 'Group':
credentials_str = get_environment_access_token(client, share.environment.environmentUri, group)
credentials_str = get_environment_access_token(client, share_environment.environmentUri, group)
credentials = json.loads(credentials_str)
session = boto3.Session(
aws_access_key_id=credentials['AccessKey'],
Expand All @@ -169,7 +170,7 @@ def check_share_items_access(
f'arn:aws:s3:{dataset.region}:{dataset.AwsAccountId}:accesspoint/{consumption_data.s3AccessPointName}'
)
if principal_type == 'Group':
workgroup = athena_client.get_env_work_group(share.environment.label)
workgroup = athena_client.get_env_work_group(share_environment.label)
athena_workgroup_output_location = None
else:
workgroup = 'primary'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def test_reject_share(client1, client5, session_cross_acc_env_1, session_s3_data


def test_change_share_purpose(client5, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
change_request_purpose = update_share_request_reason(client5, share.shareUri, 'new purpose')
assert_that(change_request_purpose).is_true()
updated_share = get_share_object(client5, share.shareUri)
Expand All @@ -153,37 +153,42 @@ def test_submit_object(client5, share_params_all):

@pytest.mark.dependency(name='share_approved', depends=['share_submitted'])
def test_approve_share(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_approve_share_object(client1, share.shareUri)


@pytest.mark.dependency(name='share_succeeded', depends=['share_approved'])
def test_share_succeeded(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_share_succeeded(client1, share.shareUri, check_contains_all_item_types=True)


@pytest.mark.dependency(name='share_verified', depends=['share_succeeded'])
def test_verify_share_items(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_verify_share_items(client1, share.shareUri)


@pytest.mark.dependency(depends=['share_verified'])
def test_check_item_access(
client5, session_cross_acc_env_1_aws_client, share_params_main, group5, session_consumption_role_1
):
share, dataset = share_params_main
share, dataset, share_environment = share_params_main
check_share_items_access(
client5, group5, share.shareUri, session_consumption_role_1, session_cross_acc_env_1_aws_client
client5,
group5,
share.shareUri,
share_environment,
session_consumption_role_1,
session_cross_acc_env_1_aws_client,
)


@pytest.mark.dependency(name='unhealthy_items', depends=['share_verified'])
def test_unhealthy_items(
client5, session_cross_acc_env_1_aws_client, session_cross_acc_env_1_integration_role_arn, share_params_main
):
share, _ = share_params_main
share, _, _ = share_params_main
iam = session_cross_acc_env_1_aws_client.resource('iam')
principal_role = iam.Role(share.principal.principalRoleName)
# break s3 by removing policies
Expand All @@ -209,7 +214,7 @@ def test_unhealthy_items(

@pytest.mark.dependency(depends=['share_approved'])
def test_reapply_unauthoried(client5, share_params_main):
share, _ = share_params_main
share, _, _ = share_params_main
share_uri = share.shareUri
share_object = get_share_object(client5, share_uri)
item_uris = [item.shareItemUri for item in share_object['items'].nodes]
Expand All @@ -220,7 +225,7 @@ def test_reapply_unauthoried(client5, share_params_main):

@pytest.mark.dependency(depends=['share_approved'])
def test_reapply(client1, share_params_main):
share, _ = share_params_main
share, _, _ = share_params_main
share_uri = share.shareUri
share_object = get_share_object(client1, share_uri)
item_uris = [item.shareItemUri for item in share_object['items'].nodes]
Expand All @@ -233,7 +238,7 @@ def test_reapply(client1, share_params_main):

@pytest.mark.dependency(name='share_revoked', depends=['share_succeeded'])
def test_revoke_share(client1, share_params_main):
share, dataset = share_params_main
share, dataset, _ = share_params_main
check_share_ready(client1, share.shareUri)
revoke_and_check_all_shared_items(client1, share.shareUri, check_contains_all_item_types=True)

Expand All @@ -242,8 +247,13 @@ def test_revoke_share(client1, share_params_main):
def test_revoke_succeeded(
client1, client5, session_cross_acc_env_1_aws_client, share_params_main, group5, session_consumption_role_1
):
share, dataset = share_params_main
share, dataset, share_environment = share_params_main
check_all_items_revoke_job_succeeded(client1, share.shareUri, check_contains_all_item_types=True)
check_share_items_access(
client5, group5, share.shareUri, session_consumption_role_1, session_cross_acc_env_1_aws_client
client5,
group5,
share.shareUri,
share_environment,
session_consumption_role_1,
session_cross_acc_env_1_aws_client,
)
Loading
Loading