correct queries and attributes
Sofia Sazonova committed Dec 17, 2024
1 parent eb4915e commit b4df733
Showing 8 changed files with 32 additions and 18 deletions.
1 change: 1 addition & 0 deletions backend/dataall/modules/s3_datasets/api/dataset/types.py
@@ -31,6 +31,7 @@
         gql.Field(name='region', type=gql.String),
         gql.Field(name='S3BucketName', type=gql.String),
         gql.Field(name='GlueDatabaseName', type=gql.String),
+        gql.Field(name='GlueCrawlerName', type=gql.String),
         gql.Field(name='IAMDatasetAdminRoleArn', type=gql.String),
         gql.Field(name='KmsAlias', type=gql.String),
         gql.Field(name='importedS3Bucket', type=gql.Boolean),
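For context, GlueCrawlerName now travels with the other restricted dataset properties, so clients read it from the restricted block. A minimal sketch of such a query in the style of the test helpers changed below; the getDataset operation name and the client wrapper are assumptions for illustration, not part of this commit:

# Hypothetical helper mirroring the query pattern in queries.py below.
# 'getDataset' and client.query(...) are assumptions for illustration only.
def get_dataset_crawler_name(client, datasetUri):
    query = {
        'operationName': 'getDataset',
        'variables': {'datasetUri': datasetUri},
        'query': """
            query getDataset($datasetUri: String!) {
                getDataset(datasetUri: $datasetUri) {
                    datasetUri
                    restricted {
                        GlueDatabaseName
                        GlueCrawlerName
                    }
                }
            }
        """,
    }
    response = client.query(query=query)
    return response.data.getDataset.restricted.GlueCrawlerName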
@@ -17,7 +17,9 @@
     sync_tables,
     create_folder,
     create_table_data_filter,
+    list_dataset_tables
 )

+from tests_new.integration_tests.modules.datasets_base.queries import list_datasets
 from integration_tests.aws_clients.s3 import S3Client as S3CommonClient
 from integration_tests.modules.s3_datasets.aws_clients import S3Client, KMSClient, GlueClient, LakeFormationClient
@@ -179,8 +181,8 @@ def create_tables(client, dataset):
         aws_session_token=creds['sessionToken'],
     )
     file_path = os.path.join(os.path.dirname(__file__), 'sample_data/csv_table/csv_sample.csv')
-    s3_client = S3Client(dataset_session, dataset.region)
-    glue_client = GlueClient(dataset_session, dataset.region)
+    s3_client = S3Client(dataset_session, dataset.restricted.region)
+    glue_client = GlueClient(dataset_session, dataset.restricted.region)
     s3_client.upload_file_to_prefix(
         local_file_path=file_path, s3_path=f'{dataset.restricted.S3BucketName}/integrationtest1'
     )
@@ -198,8 +200,9 @@ def create_tables(client, dataset):
         table_name='integrationtest2',
         bucket=dataset.restricted.S3BucketName,
     )
-    response = sync_tables(client, datasetUri=dataset.datasetUri)
-    return [table for table in response.get('nodes', []) if table.GlueTableName.startswith('integrationtest')]
+    sync_tables(client, datasetUri=dataset.datasetUri)
+    response = list_dataset_tables(client, datasetUri=dataset.datasetUri)
+    return [table for table in response.tables.get('nodes', []) if table.restricted.GlueTableName.startswith('integrationtest')]


def create_folders(client, dataset):
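The fixture changes above follow a single theme: region, S3BucketName, GlueDatabaseName, and GlueTableName are no longer top-level attributes but live under a restricted group, and table nodes are read back via list_dataset_tables rather than taken from the sync_tables response. A sketch of the object shape the tests now assume (field grouping taken from this commit's queries; values hypothetical, and the test client actually wraps responses in attribute-accessible objects rather than plain dicts):

# Hypothetical dataset payload after this commit; sensitive attributes
# sit under 'restricted' instead of at the top level.
dataset = {
    'datasetUri': 'abc123',
    'AwsAccountId': '111122223333',
    'restricted': {
        'region': 'eu-west-1',
        'S3BucketName': 'example-dataset-bucket',
        'GlueDatabaseName': 'example_glue_db',
        'GlueCrawlerName': 'example-crawler',
    },
}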
10 changes: 9 additions & 1 deletion tests_new/integration_tests/modules/s3_datasets/queries.py
@@ -18,6 +18,7 @@
                 KmsAlias
                 S3BucketName
                 GlueDatabaseName
+                GlueCrawlerName
                 IAMDatasetAdminRoleArn
             }
             environment {
@@ -352,6 +353,7 @@ def update_folder(client, locationUri, input):
         mutation updateDatasetStorageLocation($locationUri: String!, $input: ModifyDatasetStorageLocationInput!) {{
             updateDatasetStorageLocation(locationUri: $locationUri, input: $input) {{
                 locationUri
+                label
             }}
         }}
         """,
@@ -375,7 +377,11 @@ def get_folder(client, locationUri):
""",
}
response = client.query(query=query)
return response.data.getDatasetStorageLocation
print(response)
if not response.errors:
return response.data.getDatasetStorageLocation
else:
return response.errors[0].message


## Tables Queries/Mutations
@@ -500,6 +506,8 @@ def list_dataset_tables(client, datasetUri):
             tables {{
                 count
                 nodes {{
+                    tableUri
+                    label
                     restricted {{
                         GlueTableName
                     }}
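Taken together, the queries.py changes give get_folder a dual return contract: the folder payload on success, or the first GraphQL error message as a plain string. A short usage sketch under that contract (client1/client2 and the locationUri value are placeholders for the suite's fixtures):

# Usage sketch for the reworked get_folder.
folder = get_folder(client1, locationUri='example-location-uri')
print(folder.locationUri)  # authorized caller: folder payload with the selected fields

maybe_error = get_folder(client2, locationUri='example-location-uri')
# unauthorized caller: a string such as
# 'UnauthorizedOperation ... GET_DATASET_FOLDER ...'
assert 'UnauthorizedOperation' in maybe_error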
@@ -223,7 +223,7 @@ def test_start_crawler(client1, dataset_fixture_name, request):
     dataset = request.getfixturevalue(dataset_fixture_name)
     dataset_uri = dataset.datasetUri
     response = start_glue_crawler(client1, datasetUri=dataset_uri, input={})
-    assert_that(response.Name).is_equal_to(dataset.GlueCrawlerName)
+    assert_that(response.Name).is_equal_to(dataset.restricted.GlueCrawlerName)
     # TODO: check it can run successfully + check sending prefix - We should first implement it in API


@@ -49,7 +49,8 @@ def test_get_folder(client1, folders_fixture_name, request):
 def test_get_folder_unauthorized(client2, folders_fixture_name, request):
     folders = request.getfixturevalue(folders_fixture_name)
     folder = folders[0]
-    assert_that(get_folder).raises(GqlError).when_called_with(client2, locationUri=folder.locationUri).contains(
+    to_be_error = get_folder(client2, locationUri=folder.locationUri)
+    assert_that(to_be_error).contains(
         'UnauthorizedOperation', 'GET_DATASET_FOLDER', folder.locationUri
     )

@@ -59,7 +59,7 @@ def test_list_dataset_tables(client1, dataset_fixture_name, request):
     dataset = request.getfixturevalue(dataset_fixture_name)
     response = list_dataset_tables(client1, dataset.datasetUri)
     assert_that(response.tables.count).is_greater_than_or_equal_to(2)
-    tables = [table for table in response.tables.get('nodes', []) if table.GlueTableName.startswith('integrationtest')]
+    tables = [table for table in response.tables.get('nodes', []) if table.restricted.GlueTableName.startswith('integrationtest')]
     assert_that(len(tables)).is_equal_to(2)


@@ -116,11 +116,12 @@ def test_delete_table(client1, dataset_fixture_name, request):
         aws_secret_access_key=creds['SessionKey'],
         aws_session_token=creds['sessionToken'],
     )
-    GlueClient(dataset_session, dataset.region).create_table(
+    GlueClient(dataset_session, dataset.restricted.region).create_table(
         database_name=dataset.restricted.GlueDatabaseName, table_name='todelete', bucket=dataset.restricted.S3BucketName
     )
-    response = sync_tables(client1, datasetUri=dataset.datasetUri)
-    table_uri = [table.tableUri for table in response.get('nodes', []) if table.label == 'todelete'][0]
+    sync_tables(client1, datasetUri=dataset.datasetUri)
+    response = list_dataset_tables(client1, datasetUri=dataset.datasetUri)
+    table_uri = [table.tableUri for table in response.tables.get('nodes', []) if table.label == 'todelete'][0]
     response = delete_table(client1, table_uri)
     assert_that(response).is_true()

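Because table nodes are no longer taken from the sync_tables response, the delete test resolves the tableUri by label through list_dataset_tables before deleting. A condensed sketch of that lookup-then-delete flow (names as imported in this suite; 'todelete' matches the Glue table created in the test):

# Condensed flow from test_delete_table above.
sync_tables(client1, datasetUri=dataset.datasetUri)  # register the new Glue table
nodes = list_dataset_tables(client1, datasetUri=dataset.datasetUri).tables.get('nodes', [])
table_uri = next(t.tableUri for t in nodes if t.label == 'todelete')  # resolve URI by label
assert_that(delete_table(client1, table_uri)).is_true()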
@@ -44,10 +44,10 @@ def test_start_table_profiling(client1, dataset_fixture_name, tables_fixture_nam
     table = tables[0]
     dataset_uri = dataset.datasetUri
     response = start_dataset_profiling_run(
-        client1, input={'datasetUri': dataset_uri, 'tableUri': table.tableUri, 'GlueTableName': table.GlueTableName}
+        client1, input={'datasetUri': dataset_uri, 'tableUri': table.tableUri, 'GlueTableName': table.restricted.GlueTableName}
     )
     assert_that(response.datasetUri).is_equal_to(dataset_uri)
-    assert_that(response.GlueTableName).is_equal_to(table.GlueTableName)
+    assert_that(response.GlueTableName).is_equal_to(table.restricted.GlueTableName)


@pytest.mark.parametrize('dataset_fixture_name', ['session_s3_dataset1'])
@@ -90,7 +90,7 @@ def test_get_table_profiling_run_by_confidentiality(client2, tables_fixture_name
     table_uri = tables[0].tableUri
     if confidentiality in ['Unclassified']:
         response = get_table_profiling_run(client2, tableUri=table_uri)
-        assert_that(response.GlueTableName).is_equal_to(tables[0].GlueTableName)
+        assert_that(response.GlueTableName).is_equal_to(tables[0].restricted.GlueTableName)
     else:
         assert_that(get_table_profiling_run).raises(GqlError).when_called_with(client2, table_uri).contains(
             'UnauthorizedOperation', 'GET_TABLE_PROFILING_METRICS'
@@ -154,20 +154,20 @@ def check_share_items_access(
         )
     elif principal_type == 'ConsumptionRole':
         session = STSClient(
-            role_arn=consumption_role.IAMRoleArn, region=dataset.region, session_name='ConsumptionRole'
+            role_arn=consumption_role.IAMRoleArn, region=dataset.restricted.region, session_name='ConsumptionRole'
         ).get_role_session(env_client)
     else:
         raise Exception('wrong principal type')

-    s3_client = S3Client(session, dataset.region)
-    athena_client = AthenaClient(session, dataset.region)
+    s3_client = S3Client(session, dataset.restricted.region)
+    athena_client = AthenaClient(session, dataset.restricted.region)

     consumption_data = get_s3_consumption_data(client, shareUri)
     items = share['items'].nodes

     glue_db = consumption_data.sharedGlueDatabase
     access_point_arn = (
-        f'arn:aws:s3:{dataset.region}:{dataset.AwsAccountId}:accesspoint/{consumption_data.s3AccessPointName}'
+        f'arn:aws:s3:{dataset.restricted.region}:{dataset.AwsAccountId}:accesspoint/{consumption_data.s3AccessPointName}'
     )
     if principal_type == 'Group':
         workgroup = athena_client.get_env_work_group(share_environment.label)
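The ARN assembled here follows the standard S3 access point format, with the region now sourced from the restricted block while the account id is still read from the top-level dataset object, as in the diff above. A small self-contained sketch of the construction (values hypothetical):

# S3 access point ARN format: arn:aws:s3:<region>:<account-id>:accesspoint/<name>
def s3_access_point_arn(region: str, account_id: str, name: str) -> str:
    return f'arn:aws:s3:{region}:{account_id}:accesspoint/{name}'

# e.g. s3_access_point_arn('eu-west-1', '111122223333', 'example-ap')
# -> 'arn:aws:s3:eu-west-1:111122223333:accesspoint/example-ap'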
