Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff code auto-format #1105

Merged
merged 8 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 4 additions & 4 deletions .github/workflows/flake8.yml → .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: flake8
name: ruff

on:
workflow_dispatch:
Expand Down Expand Up @@ -27,6 +27,6 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install isort
python -m pip install flake8
- name: flake8
run: python -m flake8 --exclude cdk.out,blueprints --ignore E402,E501,F841,W503,F405,F403,F401,E712,E203 backend/
python -m pip install ruff
- name: ruff
run: ruff check
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,4 @@ npm-debug.log*
yarn-debug.log*
yarn-error.log*
.idea
/.ruff_cache/
14 changes: 9 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
repos:
- repo: https://github.com/psf/black
rev: 22.8.0
hooks:
- id: black
entry: black --skip-string-normalization backend/dataall
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.2
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
hooks:
Expand Down
5 changes: 3 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ install-tests:
pip install -r tests/requirements.txt

lint:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't make the Makefile target tool-specific; let's keep it as lint

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can keep both? I don't want the name to be confusing, since it's just an alias

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously we were running flake8 and we called the target lint and not flake8 which I think is correct because it's abstracting whatever underlying linter we use and also that way we wouldn't have to change the pipelines code.

pip install flake8
python -m flake8 --exclude cdk.out,blueprints --ignore E402,E501,F841,W503,F405,F403,F401,E712,E203 backend/
pip install ruff
ruff check --fix
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we run ruff format here as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way, I forgot to say that we should take into account both the check and format exit codes and fail the target if either of them fails.

ruff format

bandit:
pip install bandit
Expand Down
50 changes: 21 additions & 29 deletions backend/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ def get_cognito_groups(claims):
groups = list()
saml_groups = claims.get('custom:saml.groups', '')
if len(saml_groups):
groups: list = (
saml_groups.replace('[', '').replace(']', '').replace(', ', ',').split(',')
)
groups: list = saml_groups.replace('[', '').replace(']', '').replace(', ', ',').split(',')
cognito_groups = claims.get('cognito:groups', '')
if len(cognito_groups):
groups.extend(cognito_groups.split(','))
Expand Down Expand Up @@ -138,20 +136,16 @@ def handler(event, context):
log.debug('username is %s', username)
try:
groups = []
if (os.environ.get('custom_auth', None)):
if os.environ.get('custom_auth', None):
groups.extend(get_custom_groups(user_id))
else:
groups.extend(get_cognito_groups(claims))
log.debug('groups are %s', ",".join(groups))
log.debug('groups are %s', ','.join(groups))
with ENGINE.scoped_session() as session:
for group in groups:
policy = TenantPolicy.find_tenant_policy(
session, group, 'dataall'
)
policy = TenantPolicy.find_tenant_policy(session, group, 'dataall')
if not policy:
print(
f'No policy found for Team {group}. Attaching TENANT_ALL permissions'
)
print(f'No policy found for Team {group}. Attaching TENANT_ALL permissions')
TenantPolicy.attach_group_tenant_policy(
session=session,
group=group,
Expand All @@ -174,9 +168,11 @@ def handler(event, context):

# Determine if there are any Operations that Require ReAuth From SSM Parameter
try:
reauth_apis = ParameterStoreManager.get_parameter_value(region=os.getenv('AWS_REGION', 'eu-west-1'), parameter_path=f"/dataall/{ENVNAME}/reauth/apis").split(',')
except Exception as e:
log.info("No ReAuth APIs Found in SSM")
reauth_apis = ParameterStoreManager.get_parameter_value(
region=os.getenv('AWS_REGION', 'eu-west-1'), parameter_path=f'/dataall/{ENVNAME}/reauth/apis'
).split(',')
except Exception:
log.info('No ReAuth APIs Found in SSM')
reauth_apis = None
else:
raise Exception(f'Could not initialize user context from event {event}')
Expand All @@ -187,23 +183,21 @@ def handler(event, context):
if reauth_apis and query.get('operationName', None) in reauth_apis:
now = datetime.datetime.now(datetime.timezone.utc)
try:
auth_time_datetime = datetime.datetime.fromtimestamp(int(claims["auth_time"]), tz=datetime.timezone.utc)
auth_time_datetime = datetime.datetime.fromtimestamp(int(claims['auth_time']), tz=datetime.timezone.utc)
if auth_time_datetime + datetime.timedelta(minutes=REAUTH_TTL) < now:
raise Exception("ReAuth")
raise Exception('ReAuth')
except Exception as e:
log.info(f'ReAuth Required for User {username} on Operation {query.get("operationName", "")}, Error: {e}')
response = {
"data": {query.get('operationName', 'operation') : None},
"errors": [
'data': {query.get('operationName', 'operation'): None},
'errors': [
{
"message": f"ReAuth Required To Perform This Action {query.get('operationName', '')}",
"locations": None,
"path": [query.get('operationName', '')],
"extensions": {
"code": "REAUTH"
}
'message': f"ReAuth Required To Perform This Action {query.get('operationName', '')}",
'locations': None,
'path': [query.get('operationName', '')],
'extensions': {'code': 'REAUTH'},
}
]
],
}
return {
'statusCode': 401,
Expand All @@ -213,12 +207,10 @@ def handler(event, context):
'Access-Control-Allow-Headers': '*',
'Access-Control-Allow-Methods': '*',
},
'body': json.dumps(response)
'body': json.dumps(response),
}

success, response = graphql_sync(
schema=executable_schema, data=query, context_value=app_context
)
success, response = graphql_sync(schema=executable_schema, data=query, context_value=app_context)

dispose_context()
response = json.dumps(response)
Expand Down
22 changes: 6 additions & 16 deletions backend/cdkproxymain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
logger = logging.getLogger('cdksass')

ENVNAME = os.getenv('envname', 'local')
logger.warning(
f"Application started for envname= `{ENVNAME}` DH_DOCKER_VERSION:{os.environ.get('DH_DOCKER_VERSION')}"
)
logger.warning(f"Application started for envname= `{ENVNAME}` DH_DOCKER_VERSION:{os.environ.get('DH_DOCKER_VERSION')}")


def connect():
Expand All @@ -30,7 +28,7 @@ def connect():
with engine.scoped_session() as session:
orgs = session.query(Organization).all()
return engine
except Exception as e:
except Exception:
raise Exception('Connection Error')


Expand All @@ -52,11 +50,7 @@ def check_creds(response: Response):
logger.info('GET /awscreds')
try:
region = os.getenv('AWS_REGION', 'eu-west-1')
sts = boto3.client(
'sts',
region_name=region,
endpoint_url=f"https://sts.{region}.amazonaws.com"
)
sts = boto3.client('sts', region_name=region, endpoint_url=f'https://sts.{region}.amazonaws.com')
data = sts.get_caller_identity()
return {
'DH_DOCKER_VERSION': os.environ.get('DH_DOCKER_VERSION'),
Expand Down Expand Up @@ -88,7 +82,7 @@ def check_connect(response: Response):
return {
'DH_DOCKER_VERSION': os.environ.get('DH_DOCKER_VERSION'),
'_ts': datetime.now().isoformat(),
'message': f"Connected to database for environment {ENVNAME}({engine.dbconfig.host})",
'message': f'Connected to database for environment {ENVNAME}({engine.dbconfig.host})',
}
except Exception as e:
logger.exception('DBCONNECTIONERROR')
Expand Down Expand Up @@ -123,9 +117,7 @@ def check_cdk_installed(response: Response):


@app.post('/stack/{stackid}', status_code=status.HTTP_202_ACCEPTED)
async def create_stack(
stackid: str, background_tasks: BackgroundTasks, response: Response
):
async def create_stack(stackid: str, background_tasks: BackgroundTasks, response: Response):
"""Deploys or updates the stack"""
logger.info(f'POST /stack/{stackid}')
try:
Expand Down Expand Up @@ -174,9 +166,7 @@ async def create_stack(


@app.delete('/stack/{stackid}', status_code=status.HTTP_202_ACCEPTED)
async def delete_stack(
stackid: str, background_tasks: BackgroundTasks, response: Response
):
async def delete_stack(stackid: str, background_tasks: BackgroundTasks, response: Response):
"""
Deletes the stack
"""
Expand Down
1 change: 0 additions & 1 deletion backend/dataall/base/api/gql/graphql_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,5 @@ def gql(self, with_directives=True):


if __name__ == '__main__':

uri = DirectiveArgs(name='uri', model='X', param=2, bool=True)
print(uri.gql())
6 changes: 2 additions & 4 deletions backend/dataall/base/api/gql/graphql_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from .graphql_argument import Argument
from .graphql_enum import GraphqlEnum
from .graphql_scalar import *
from .graphql_scalar import Scalar
from .graphql_type import ObjectType
from .graphql_type_modifiers import ArrayType, NonNullableType, TypeModifier
from .graphql_union_type import Union
Expand Down Expand Up @@ -65,9 +65,7 @@ def gql(self, with_directives=True) -> str:
return f'{gql}'

def directive(self, directive_name):
return next(
filter(lambda d: d.name == directive_name, self.directives or []), None
)
return next(filter(lambda d: d.name == directive_name, self.directives or []), None)

def has_directive(self, directive_name):
return self.directive(directive_name=directive_name) is not None
Expand Down
4 changes: 3 additions & 1 deletion backend/dataall/base/api/gql/graphql_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .graphql_type import ObjectType
from dataall.base.api.gql import ObjectType
from dataall.base.api.gql.graphql_type_modifiers import NonNullableType
from dataall.base.api.gql.graphql_scalar import String


class Interface(ObjectType):
Expand Down
12 changes: 6 additions & 6 deletions backend/dataall/base/api/gql/graphql_type_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .graphql_enum import GraphqlEnum as Enum
from .graphql_input import InputType
from .graphql_scalar import Scalar
from .graphql_type import ObjectType
from .ref import Ref
from .thunk import Thunk
from dataall.base.api.gql.graphql_enum import GraphqlEnum as Enum
from dataall.base.api.gql.graphql_input import InputType
from dataall.base.api.gql.graphql_scalar import Scalar
from dataall.base.api.gql.graphql_type import ObjectType
from dataall.base.api.gql.ref import Ref
from dataall.base.api.gql.thunk import Thunk


class TypeModifier:
Expand Down
2 changes: 1 addition & 1 deletion backend/dataall/base/api/gql/graphql_union_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class UnionTypeRegistry(ABC):

@classmethod
def types(cls):
raise NotImplementedError("Types method is not implemented")
raise NotImplementedError('Types method is not implemented')


@cache_instances
Expand Down
15 changes: 4 additions & 11 deletions backend/dataall/base/api/gql/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .graphql_field import Field
from .graphql_scalar import *
from .graphql_scalar import String
from .graphql_type import ObjectType


Expand All @@ -18,17 +18,13 @@ def update_context(self, key, value):

def ensure_query(self):
if not self.type('Query'):
self.add_type(
ObjectType(name='Query', fields=[Field(name='test', type=String)])
)
self.add_type(ObjectType(name='Query', fields=[Field(name='test', type=String)]))
elif not len(self.type('Query').fields):
self.type('Query').add_field(field=Field(name='test', type=String))

def ensure_mutation(self):
if not self.type('Mutation'):
self.add_type(
ObjectType(name='Mutation', fields=[Field(name='test', type=String)])
)
self.add_type(ObjectType(name='Mutation', fields=[Field(name='test', type=String)]))
elif not len(self.type('Mutation').fields):
self.type('Mutation').add_field(field=Field(name='test', type=String))

Expand Down Expand Up @@ -117,9 +113,6 @@ def resolve(self, path, context, source, **kwargs):


if __name__ == '__main__':

schema = Schema(
types=[ObjectType(name='Account', fields=[Field(name='uri', type=String)])]
)
schema = Schema(types=[ObjectType(name='Account', fields=[Field(name='uri', type=String)])])

print(schema.gql())
2 changes: 1 addition & 1 deletion backend/dataall/base/api/gql/thunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def gql(self):

if __name__ == '__main__':
from ..gql import Field
from ..gql import *
from ..gql import String

Foo = Field(name='foo', type=String)
t = Thunk(lambda: Foo)
Expand Down
2 changes: 1 addition & 1 deletion backend/dataall/base/api/gql/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import types

from .graphql_type_modifiers import *
from .graphql_type_modifiers import ArrayType, Enum, InputType, NonNullableType, ObjectType, Scalar
from .ref import Ref
from .thunk import Thunk

Expand Down
8 changes: 2 additions & 6 deletions backend/dataall/base/api/gql/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,7 @@ def visit(self):
for object_type in self.schema.types:
self.enter_type(object_type=object_type, schema=self.schema)
for field in object_type.fields:
self.enter_field(
field=field, object_type=object_type, schema=self.schema
)
self.leave_field(
field=field, object_type=object_type, schema=self.schema
)
self.enter_field(field=field, object_type=object_type, schema=self.schema)
self.leave_field(field=field, object_type=object_type, schema=self.schema)
self.leave_type(object_type=object_type, schema=self.schema)
self.leave_schema(schema=self.schema)
14 changes: 3 additions & 11 deletions backend/dataall/base/aws/cognito.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class Cognito(ServiceProvider):

def __init__(self):
self.client = boto3.client('cognito-idp', region_name=os.getenv('AWS_REGION', 'eu-west-1'))

Expand All @@ -19,10 +18,7 @@ def get_user_emailids_from_group(self, groupName):
ssm = boto3.client('ssm', region_name=os.getenv('AWS_REGION', 'eu-west-1'))
user_pool_id = ssm.get_parameter(Name=parameter_path)['Parameter']['Value']
paginator = self.client.get_paginator('list_users_in_group')
pages = paginator.paginate(
UserPoolId=user_pool_id,
GroupName=groupName
)
pages = paginator.paginate(UserPoolId=user_pool_id, GroupName=groupName)
cognito_user_list = []
for page in pages:
cognito_user_list += page['Users']
Expand All @@ -38,9 +34,7 @@ def get_user_emailids_from_group(self, groupName):
if envname in ['local', 'dkrcompose']:
log.error('Local development environment does not support Cognito')
return ['[email protected]']
log.error(
f'Failed to get email ids for Cognito group {groupName} due to {e}'
)
log.error(f'Failed to get email ids for Cognito group {groupName} due to {e}')
raise e
else:
return group_email_ids
Expand All @@ -58,9 +52,7 @@ def list_groups(self, envname: str, region: str):
for page in pages:
groups += [gr['GroupName'] for gr in page['Groups']]
except Exception as e:
log.error(
f'Failed to list groups of user pool {user_pool_id} due to {e}'
)
log.error(f'Failed to list groups of user pool {user_pool_id} due to {e}')
raise e
return groups

Expand Down
Loading
Loading