Skip to content

Commit

Permalink
Merge pull request #38 from MolarVerse/feature/beartype_exceptions
Browse files Browse the repository at this point in the history
Feature/beartype exceptions
  • Loading branch information
97gamjak authored May 6, 2024
2 parents a345b5b + 3585a95 commit e584ff4
Show file tree
Hide file tree
Showing 16 changed files with 223 additions and 52 deletions.
17 changes: 16 additions & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ permissions:

jobs:
pylint:

runs-on: ubuntu-latest

permissions: # Job-level permissions configuration starts here
contents: write # 'write' access to repository contents
pull-requests: write # 'write' access to pull requests

steps:
- uses: actions/checkout@master
with:
persist-credentials: false # otherwise, the token used is the GITHUB_TOKEN, instead of your personal access token.
fetch-depth: 0 # otherwise, there would be errors pushing refs to the destination repository.

- name: Setup Python
uses: actions/setup-python@v2
with:
Expand Down Expand Up @@ -97,5 +106,11 @@ jobs:
git add .github/.pylint_cache
git commit -m "Add .github/.pylint_cache on push event"
git push
- name: Push changes
if: github.event_name == 'push'
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
branch: ${{ github.ref }}

20 changes: 17 additions & 3 deletions PQAnalysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,26 @@

from PQAnalysis.utils.custom_logging import CustomLogger

beartype_this_package()

__base_path__ = Path(__file__).parent

__package_name__ = __name__

##################
# BEARTYPE SETUP #
##################

# TODO: change the default level to "RELEASE" after all changes are implemented
__beartype_default_level__ = "DEBUG"
__beartype_level__ = os.getenv(
"PQANALYSIS_BEARTYPE_LEVEL", __beartype_default_level__
)

if __beartype_level__.upper() == "DEBUG":
beartype_this_package()

#################
# LOGGING SETUP #
#################

logging_env_var = os.getenv("PQANALYSIS_LOGGING_LEVEL")

if logging_env_var and logging_env_var not in logging.getLevelNamesMapping():
Expand Down
9 changes: 6 additions & 3 deletions PQAnalysis/analysis/rdf/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,36 @@
from PQAnalysis.io import TrajectoryReader, RestartFileReader, MoldescriptorReader
from PQAnalysis.traj import MDEngineFormat
from PQAnalysis.topology import Topology
from PQAnalysis.type_checking import runtime_type_checking

from .rdf import RDF
from .rdf_input_file_reader import RDFInputFileReader
from .rdf_output_file_writer import RDFDataWriter, RDFLogWriter


@runtime_type_checking
def rdf(input_file: str, md_format: MDEngineFormat | str = MDEngineFormat.PQ):
"""
Calculates the radial distribution function (RDF) using a given input file.
This is just a wrapper function combining the underlying classes and functions.
For more information on the input file keys please
For more information on the input file keys please
visit :py:mod:`~PQAnalysis.analysis.rdf.rdfInputFileReader`.
For more information on the exact calculation of
the RDF please visit :py:class:`~PQAnalysis.analysis.rdf.rdf.RDF`.
Parameters
----------
input_file : str
The input file. For more information on the input file
The input file. For more information on the input file
keys please visit :py:mod:`~PQAnalysis.analysis.rdf.rdfInputFileReader`.
md_format : MDEngineFormat | str, optional
the format of the input trajectory. Default is "PQ".
the format of the input trajectory. Default is "PQ".
For more information on the supported formats please visit
:py:class:`~PQAnalysis.traj.formats.MDEngineFormat`.
"""

md_format = MDEngineFormat(md_format)

input_reader = RDFInputFileReader(input_file)
Expand Down
92 changes: 92 additions & 0 deletions PQAnalysis/type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
A module for type checking of arguments passed to functions at runtime.
"""

import logging

from decorator import decorator
from beartype.door import is_bearable

from PQAnalysis.utils.custom_logging import setup_logger
from .types import (
Np1DIntArray,
Np2DIntArray,
Np1DNumberArray,
Np2DNumberArray,
Np3x3NumberArray,
NpnDNumberArray,
)

__logger_name__ = "PQAnalysis.TypeChecking"

if not logging.getLogger(__logger_name__).handlers:
logger = setup_logger(logging.getLogger(__logger_name__))
else:
logger = logging.getLogger(__logger_name__)


@decorator
def runtime_type_checking(func, *args, **kwargs):
"""
A decorator to check the type of the arguments passed to a function at runtime.
"""

# Get the type hints of the function
type_hints = func.__annotations__

# Check the type of each argument
for arg_name, arg_value in zip(func.__code__.co_varnames, args):
if arg_name in type_hints:
if not is_bearable(arg_value, type_hints[arg_name]):
logger.error(
_get_type_error_message(
arg_name,
arg_value,
type_hints[arg_name],
),
exception=TypeError,
)

# Check the type of each keyword argument
for kwarg_name, kwarg_value in kwargs.items():
if kwarg_name in type_hints:
if not is_bearable(kwarg_value, type_hints[kwarg_name]):
logger.error(
_get_type_error_message(
kwarg_name,
kwarg_value,
type_hints[kwarg_name],
),
exception=TypeError,
)

# Call the function
return func(*args, **kwargs)


def _get_type_error_message(arg_name, value, expected_type):
"""
Get the error message for a type error.
"""

actual_type = type(value)

header = (
f"Argument '{arg_name}' with {value=} should be "
f"of type {expected_type}, but got {actual_type}."
)

if expected_type is Np1DIntArray:
header += " Expected a 1D numpy integer array."
elif expected_type is Np2DIntArray:
header += " Expected a 2D numpy integer array."
elif expected_type is Np1DNumberArray:
header += " Expected a 1D numpy number array."
elif expected_type is Np2DNumberArray:
header += " Expected a 2D numpy number array."
elif expected_type is Np3x3NumberArray:
header += " Expected a 3x3 numpy number array."
elif expected_type is NpnDNumberArray:
header += " Expected an n-dimensional numpy number array."

return header
27 changes: 18 additions & 9 deletions PQAnalysis/utils/custom_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,15 @@ def _log(self, # pylint: disable=arguments-differ
)

if level in [logging.CRITICAL, logging.ERROR]:

exception = exception or Exception

if self.isEnabledFor(logging.DEBUG):
back_tb = None

try:
if exception is not None:
raise exception

raise Exception # pylint: disable=broad-exception-raised
except Exception: # pylint: disable=broad-except
raise exception # pylint: disable=broad-exception-raised
except exception: # pylint: disable=broad-except
traceback = sys.exc_info()[2]
back_frame = traceback.tb_frame.f_back

Expand All @@ -140,12 +140,21 @@ def _log(self, # pylint: disable=arguments-differ
tb_lineno=back_frame.f_lineno
)

if exception is not None:
raise Exception(msg).with_traceback(back_tb)

raise exception(msg).with_traceback(back_tb)

sys.exit(1)
class DevNull:
"""
Dummy class to redirect the sys.stderr to /dev/null.
"""

def write(self, _):
"""
Dummy write method.
"""

sys.stderr = DevNull()

raise exception(msg) # pylint: disable=raise-missing-from

def _original_log(self,
level: Any,
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ markers =
topology
traj
io
analysis

testpaths =
tests
Expand Down
7 changes: 7 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""
Unit tests for the PQAnalysis package.
"""

import os

os.environ['PQANALYSIS_BEARTYPE_LEVEL'] = "RELEASE"
3 changes: 3 additions & 0 deletions tests/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import pytest

pytestmark = pytest.mark.analysis
Empty file added tests/analysis/rdf/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions tests/analysis/rdf/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
A module to test the RDF API.
"""

import pytest # pylint: disable=unused-import

from PQAnalysis.analysis.rdf.api import rdf
from PQAnalysis.type_checking import _get_type_error_message

from .. import pytestmark # pylint: disable=unused-import
from ...conftest import assert_logging_with_exception


class TestRDFAPI:
def test_wrong_param_types(self, caplog):
assert_logging_with_exception(
caplog=caplog,
logging_name="TypeChecking",
logging_level="ERROR",
message_to_test=_get_type_error_message(
"input_file", 1, str,
),
exception=TypeError,
function=rdf,
input_file=1,
)
4 changes: 4 additions & 0 deletions tests/analysis/rdf/test_rdfInputFileReader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from PQAnalysis.analysis.rdf.rdf_input_file_reader import RDFInputFileReader
from PQAnalysis.io.input_file_reader.exceptions import InputFileError

# import topology marker
from .. import pytestmark # pylint: disable=unused-import
from ...conftest import assert_logging


class TestRDFInputFileReader:
@pytest.mark.parametrize("example_dir", ["rdf"], indirect=False)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def assert_logging_with_exception(caplog, logging_name, logging_level, message_t
result = None
try:
result = function(*args, **kwargs)
except SystemExit:
except:
pass

record = caplog.records[0]
Expand Down
9 changes: 2 additions & 7 deletions tests/io/test_frameReader.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import pytest
import numpy as np

from beartype.roar import BeartypeException

from . import pytestmark

from PQAnalysis.io import FrameReader
from PQAnalysis.io.traj_file.exceptions import FrameReaderError
from PQAnalysis.core import Cell, Atom
from PQAnalysis.traj.exceptions import TrajectoryFormatError
from PQAnalysis.traj import TrajectoryFormat
from PQAnalysis.topology import Topology

from . import pytestmark


class TestFrameReader:

Expand Down Expand Up @@ -67,9 +65,6 @@ def test__read_scalar(self):
def test_read(self):
reader = FrameReader()

with pytest.raises(BeartypeException):
reader.read(["tmp"])

frame = reader.read(
"2 2.0 3.0 4.0 5.0 6.0 7.0\n\nh 1.0 2.0 3.0\no 2.0 2.0 2.0")
assert frame.n_atoms == 2
Expand Down
10 changes: 2 additions & 8 deletions tests/io/test_infoFileReader.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
import pytest

from beartype.roar import BeartypeException

from . import pytestmark

from PQAnalysis.io import InfoFileReader
from PQAnalysis.traj import MDEngineFormat
from PQAnalysis.traj.exceptions import MDEngineFormatError

from . import pytestmark


@pytest.mark.parametrize("example_dir", ["readInfoFile"], indirect=False)
def test__init__(test_with_data_dir):
with pytest.raises(FileNotFoundError) as exception:
InfoFileReader("tmp")
assert str(exception.value) == "File tmp not found."

with pytest.raises(BeartypeException) as exception:
InfoFileReader(
"md-01.info", engine_format=None)

with pytest.raises(MDEngineFormatError) as exception:
InfoFileReader(
"md-01.info", engine_format="tmp")
Expand Down
Loading

0 comments on commit e584ff4

Please sign in to comment.