Skip to content

Commit

Permalink
Provide a real script with multiple sub commands
Browse files Browse the repository at this point in the history
  • Loading branch information
credbbl committed Dec 15, 2023
1 parent d031feb commit fe11577
Show file tree
Hide file tree
Showing 8 changed files with 253 additions and 72 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ packages = [
{ include = "glvd", from = "src" },
]

[tool.poetry.scripts]
glvd = 'glvd.cli.__main__:main'

[tool.poetry.dependencies]
python = ">=3.11"
asyncpg = ">=0.28"
Expand Down
83 changes: 83 additions & 0 deletions src/glvd/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# SPDX-License-Identifier: MIT

from __future__ import annotations

import argparse
import dataclasses
from collections.abc import (
Callable,
Iterable,
)


@dataclasses.dataclass
class _ActionWrapper:
args: tuple
kw: dict


class Cli:
parser: argparse.ArgumentParser
subparsers: argparse._SubParsersAction

def __init__(self) -> None:
self.parser = argparse.ArgumentParser(
allow_abbrev=False,
prog='glvd',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self.subparsers = self.parser.add_subparsers(
help='sub-command help',
)

def add_argument(self, *args, **kw) -> _ActionWrapper:
return _ActionWrapper(args, kw)

def register(
self,
name: str,
arguments: Iterable[_ActionWrapper],
usage: str = '%(prog)s',
epilog: str | None = None,
) -> Callable:
parser_main = argparse.ArgumentParser(
allow_abbrev=False,
prog=f'glvd.{name}',
usage=usage,
epilog=epilog,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser_sub = self.subparsers.add_parser(
name=name,
usage=usage,
epilog=epilog,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

for p in (parser_main, parser_sub):
for w in arguments:
p.add_argument(*w.args, **w.kw)

def wrap(func: Callable) -> Callable:
parser_sub.set_defaults(func=func)

def run() -> None:
args = parser_main.parse_args()
func(**vars(args))

return run

return wrap

def main(self) -> None:
args = self.parser.parse_args()
v = vars(args)
func = v.pop('func', None)
if func:
func(**v)
else:
self.parser.print_help()


cli = Cli()
22 changes: 22 additions & 0 deletions src/glvd/cli/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# SPDX-License-Identifier: MIT

from __future__ import annotations

from . import cli

# Import to register all the commands
from . import ( # noqa: F401
combine_all,
combine_deb,
ingest_debsec,
ingest_debsrc,
ingest_nvd,
)


def main() -> None:
cli.main()


if __name__ == '__main__':
main()
34 changes: 24 additions & 10 deletions src/glvd/cli/combine_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,34 @@
)

from ..database import Base, AllCve
from . import cli


logger = logging.getLogger(__name__)


class CombineDeb:
class CombineAll:
@staticmethod
@cli.register(
'combine-all',
arguments=[
cli.add_argument(
'--database',
default='postgresql+asyncpg:///',
help='the database to use, must use asyncio compatible SQLAlchemy driver',
),
cli.add_argument(
'--debug',
action='store_true',
help='enable debug output',
),
]
)
def run(database: str, debug: bool) -> None:
logging.basicConfig(level=debug and logging.DEBUG or logging.INFO)
engine = create_async_engine(database, echo=debug)
asyncio.run(CombineAll()(engine))

stmt_combine_new = (
text('''
SELECT
Expand Down Expand Up @@ -103,12 +125,4 @@ async def __call__(


if __name__ == '__main__':
import argparse
logging.basicConfig(level=logging.DEBUG)
parser = argparse.ArgumentParser()
args = parser.parse_args()
engine = create_async_engine(
"postgresql+asyncpg:///",
)
main = CombineDeb()
asyncio.run(main(engine))
CombineAll.run()
32 changes: 23 additions & 9 deletions src/glvd/cli/combine_deb.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,34 @@
from ..database import Base, DistCpe, DebCve
from ..data.cpe import Cpe, CpeOtherDebian
from ..data.cvss import CvssSeverity
from . import cli


logger = logging.getLogger(__name__)


class CombineDeb:
@staticmethod
@cli.register(
'combine-deb',
arguments=[
cli.add_argument(
'--database',
default='postgresql+asyncpg:///',
help='the database to use, must use asyncio compatible SQLAlchemy driver',
),
cli.add_argument(
'--debug',
action='store_true',
help='enable debug output',
),
]
)
def run(database: str, debug: bool) -> None:
logging.basicConfig(level=debug and logging.DEBUG or logging.INFO)
engine = create_async_engine(database, echo=debug)
asyncio.run(CombineDeb()(engine))

stmt_combine_new = (
text('''
SELECT
Expand Down Expand Up @@ -207,12 +229,4 @@ async def __call__(


if __name__ == '__main__':
import argparse
logging.basicConfig(level=logging.DEBUG)
parser = argparse.ArgumentParser()
args = parser.parse_args()
engine = create_async_engine(
"postgresql+asyncpg:///",
)
main = CombineDeb()
asyncio.run(main(engine))
CombineDeb.run()
56 changes: 35 additions & 21 deletions src/glvd/cli/ingest_debsec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,46 @@
from ..database import Base, DistCpe, DebsecCve
from ..data.debsec_cve import DebsecCveFile
from ..data.dist_cpe import DistCpeMapper
from . import cli


logger = logging.getLogger(__name__)


class IngestDebsec:
@staticmethod
@cli.register(
'ingest-debsec',
arguments=[
cli.add_argument(
'cpe_product',
choices=sorted(DistCpeMapper.keys()),
help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}',
metavar='CPE_PRODUCT',
),
cli.add_argument(
'dir',
help='data directory out of https://salsa.debian.org/security-tracker-team/security-tracker',
metavar='DEBSEC',
type=Path,
),
cli.add_argument(
'--database',
default='postgresql+asyncpg:///',
help='the database to use, must use asyncio compatible SQLAlchemy driver',
),
cli.add_argument(
'--debug',
action='store_true',
help='enable debug output',
),
]
)
def run(cpe_product: str, dir: Path, database: str, debug: bool) -> None:
logging.basicConfig(level=debug and logging.DEBUG or logging.INFO)
engine = create_async_engine(database, echo=debug)
asyncio.run(IngestDebsec(cpe_product, dir)(engine))

def __init__(self, cpe_product: str, path: Path) -> None:
self.path = path

Expand Down Expand Up @@ -116,24 +150,4 @@ async def __call__(


if __name__ == '__main__':
import argparse
logging.basicConfig(level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument(
'cpe_product',
choices=sorted(DistCpeMapper.keys()),
help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}',
metavar='CPE_PRODUCT',
)
parser.add_argument(
'dir',
help='data directory out of https://salsa.debian.org/security-tracker-team/security-tracker',
metavar='DEBSEC',
type=Path,
)
args = parser.parse_args()
engine = create_async_engine(
"postgresql+asyncpg:///",
)
ingest = IngestDebsec(args.cpe_product, args.dir)
asyncio.run(ingest(engine))
IngestDebsec.run()
66 changes: 40 additions & 26 deletions src/glvd/cli/ingest_debsrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,51 @@
from ..database import Base, DistCpe, Debsrc
from ..data.debsrc import DebsrcFile
from ..data.dist_cpe import DistCpeMapper
from . import cli


logger = logging.getLogger(__name__)


class IngestDebsrc:
@staticmethod
@cli.register(
'ingest-debsrc',
arguments=[
cli.add_argument(
'cpe_product',
choices=sorted(DistCpeMapper.keys()),
help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}',
metavar='CPE_PRODUCT',
),
cli.add_argument(
'deb_codename',
help='codename of APT archive',
metavar='CODENAME',
),
cli.add_argument(
'file',
help='uncompressed Sources file',
metavar='SOURCES',
type=Path,
),
cli.add_argument(
'--database',
default='postgresql+asyncpg:///',
help='the database to use, must use asyncio compatible SQLAlchemy driver',
),
cli.add_argument(
'--debug',
action='store_true',
help='enable debug output',
),
]
)
def run(cpe_product: str, deb_codename: str, file: Path, database: str, debug: bool) -> None:
logging.basicConfig(level=debug and logging.DEBUG or logging.INFO)
engine = create_async_engine(database, echo=debug)
asyncio.run(IngestDebsrc(cpe_product, deb_codename, file)(engine))

def __init__(self, cpe_product: str, deb_codename: str, file: Path) -> None:
self.file = file

Expand Down Expand Up @@ -103,29 +142,4 @@ async def __call__(


if __name__ == '__main__':
import argparse
logging.basicConfig(level=logging.DEBUG)
parser = argparse.ArgumentParser()
parser.add_argument(
'cpe_product',
choices=sorted(DistCpeMapper.keys()),
help=f'CPE product used for data, supported: {" ".join(sorted(DistCpeMapper.keys()))}',
metavar='CPE_PRODUCT',
)
parser.add_argument(
'deb_codename',
help='codename of APT archive',
metavar='CODENAME',
)
parser.add_argument(
'file',
help='uncompressed Sources file',
metavar='SOURCES',
type=Path,
)
args = parser.parse_args()
engine = create_async_engine(
"postgresql+asyncpg:///",
)
ingest = IngestDebsrc(args.cpe_product, args.deb_codename, args.file)
asyncio.run(ingest(engine))
IngestDebsrc.run()
Loading

0 comments on commit fe11577

Please sign in to comment.