Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Audit mode based on Google OSV #58

Merged
merged 5 commits into from
Sep 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions it_depends/audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from requests import post
from tqdm import tqdm
from typing import Dict, FrozenSet, Iterable, List, Union, Tuple

from .dependencies import Package, PackageRepository, Vulnerability

logger = logging.getLogger(__name__)


class OSVVulnerability(Vulnerability):
    """Represents a vulnerability from the OSV project.

    Besides the ``id``, ``aliases`` and ``summary`` handled by the base
    class, every key listed in ``EXTRA_KEYS`` is copied from the OSV record
    onto the instance as an attribute (``None`` when the record lacks it).
    """

    # Additional keys available from the OSV Vulnerability db.
    EXTRA_KEYS = [
        "published", "modified", "withdrawn", "related", "package",
        "details", "affects", "affected", "references", "severity",
        "database_specific", "ecosystem_specific"]

    def __init__(self, osv_dict: Dict):
        """Build the vulnerability from one decoded OSV JSON record.

        :param osv_dict: a single OSV vulnerability record; must contain
            the key ``"id"``.
        :raises KeyError: if ``osv_dict`` has no ``"id"`` entry.
        """
        # Get the first available information as summary (N/A if none)
        summary = (osv_dict.get("summary", "")
                   or osv_dict.get("details", "")
                   or "N/A")
        # ``or []`` also covers an explicit JSON ``null`` for "aliases",
        # which would otherwise make the base class call ``list(None)``.
        aliases = osv_dict.get("aliases") or []
        super().__init__(osv_dict["id"], aliases, summary)

        # Inherit all other attributes
        for k in self.EXTRA_KEYS:
            setattr(self, k, osv_dict.get(k, None))

    @classmethod
    def from_osv_dict(cls, d: Dict):
        """Alternate constructor; returns an instance of the calling class."""
        # Use ``cls`` so that subclasses get instances of their own type.
        return cls(d)


class VulnerabilityProvider(ABC):
    """Common interface shared by all sources of vulnerability data."""

    def query(self, pkg: Package) -> Iterable[Vulnerability]:
        """Look up the vulnerabilities known for ``pkg``.

        The base implementation always raises ``NotImplementedError``;
        concrete providers supply the actual lookup.
        """
        raise NotImplementedError()


class OSVProject(VulnerabilityProvider):
    """OSV project vulnerability provider"""
    QUERY_URL = "https://api.osv.dev/v1/query"

    def query(self, pkg: Package) -> Iterable[OSVVulnerability]:
        """Queries the OSV project for vulnerabilities in Package pkg.

        The package name and version are POSTed as JSON to ``QUERY_URL``
        and every entry of the ``"vulns"`` list in the reply is turned
        into an ``OSVVulnerability``.

        :param pkg: package whose ``name`` and ``version`` are looked up.
        :return: the vulnerabilities reported for ``pkg`` (possibly empty).
        :raises requests.HTTPError: if the service answers with an HTTP
            error status.
        """
        q = {"version": str(pkg.version), "package": {"name": pkg.name}}
        response = post(OSVProject.QUERY_URL, json=q)
        # Without this check an error reply carrying a JSON body has no
        # "vulns" key and would be silently reported as "no vulnerabilities".
        response.raise_for_status()
        # ``or []`` also tolerates an explicit ``"vulns": null``.
        records = response.json().get("vulns") or []
        return [OSVVulnerability.from_osv_dict(record) for record in records]


def vulnerabilities(repo: PackageRepository, nworkers=None) -> \
        PackageRepository:
    """Enrich every package in ``repo`` with vulnerability information.

    Each package is queried against the OSV project on a thread pool while
    a progress bar is shown. A failed query is logged and skipped so that a
    single error does not abort the audit of the remaining packages.

    :param repo: the packages to audit; they are updated in place.
    :param nworkers: ``max_workers`` for the ``ThreadPoolExecutor``
        (``None`` lets the executor pick its default).
    :return: the same ``repo`` object that was passed in.
    """

    def _get_vulninfo(pkg: Package) -> Tuple[Package, FrozenSet[Vulnerability]]:
        """Fetch the vulnerabilities of one package (runs on a worker)."""
        # Do not modify pkg here to ensure no concurrent
        # modifications, instead return and let the main
        # thread handle the updates.
        return pkg, frozenset(OSVProject().query(pkg))

    with ThreadPoolExecutor(max_workers=nworkers) as executor, \
            tqdm(desc="Checking for vulnerabilities", leave=False,
                 unit=" packages") as t:
        # Map each future back to its package so failures can be attributed.
        futures = {executor.submit(_get_vulninfo, pkg): pkg for pkg in repo}
        t.total = len(futures)

        for future in as_completed(futures):
            t.update(1)
            try:
                pkg, vulns = future.result()
            except Exception as exc:
                # Best effort: report the failing package and carry on.
                logger.error("Failed to retrieve vulnerability information "
                             "for %s. Exception: %s", futures[future], exc)
            else:
                pkg.update_vulnerabilities(vulns)

    return repo
1 change: 1 addition & 0 deletions it_depends/cargo.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_dependencies(repo: SourceRepository, check_for_cargo: bool = True, cache
version=Version.coerce(package["version"]),
source="cargo",
dependencies=dependencies.values(),
vulnerabilities=(),
**kwargs
)

Expand Down
7 changes: 7 additions & 0 deletions it_depends/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from sqlalchemy.exc import OperationalError

from .audit import vulnerabilities
from .db import DEFAULT_DB_PATH, DBPackageCache
from .dependencies import Dependency, resolvers, resolver_by_name, resolve, SourceRepository
from .html import graph_to_html
Expand Down Expand Up @@ -49,6 +50,8 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
"RESOLVER_NAME:PACKAGE_NAME[@OPTIONAL_VERSION], where RESOLVER_NAME is a resolver listed "
"in `it-depends --list`. For example: \"pip:numpy\", \"apt:[email protected]\", or "
"\"npm:lodash@>=4.17.0\".")

parser.add_argument("--audit", "-a", action="store_true", help="Audit packages for known vulnerabilities using Google OSV")
parser.add_argument("--list", "-l", action="store_true", help="list available package resolver")
parser.add_argument("--database", "-db", type=str, nargs="?", default=DEFAULT_DB_PATH,
help="alternative path to load/store the database, or \":memory:\" to cache all results in "
Expand Down Expand Up @@ -137,6 +140,10 @@ def main(argv: Optional[Sequence[str]] = None) -> int:
sys.stderr.write(f"Try --list to check for available resolvers for {args.PATH_OR_NAME}\n")
sys.stderr.flush()

# TODO: Should the cache be updated instead????
if args.audit:
package_list = vulnerabilities(package_list)

if to_compare is not None:
to_compare_list = \
resolve(to_compare, cache=cache, depth_limit=args.depth_limit, max_workers=args.max_workers)
Expand Down
49 changes: 46 additions & 3 deletions it_depends/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@
logger = logging.getLogger(__name__)


class Vulnerability:
    """Represents a specific vulnerability.

    Two vulnerabilities are considered equal when they share the same
    ``id``; hashing and ordering are based on ``id`` as well, so instances
    behave consistently in sets, dict keys and ``sorted()``.
    """

    def __init__(self, id: str, aliases: Iterable[str], summary: str) -> None:
        """Store the identifier, a list copy of the aliases and the summary."""
        self.id = id
        self.aliases = list(aliases)
        self.summary = summary

    def to_compact_str(self) -> str:
        """Return ``"<id> (<alias>, <alias>, ...)"``."""
        return f"{self.id} ({', '.join(self.aliases)})"

    def to_obj(self) -> Dict[str, Union[str, List[str]]]:
        """Return a plain dict with the keys ``id``, ``aliases``, ``summary``."""
        return {
            "id": self.id,
            "aliases": self.aliases,
            "summary": self.summary
        }

    def __eq__(self, other):
        if isinstance(other, Vulnerability):
            return self.id == other.id
        # Let Python fall back to the default comparison (yields False).
        return NotImplemented

    def __hash__(self):
        # Must agree with __eq__: objects that compare equal (same id) have
        # to hash equally, otherwise sets/frozensets keep "duplicates".
        return hash(self.id)

    def __lt__(self, other):
        if not isinstance(other, Vulnerability):
            raise ValueError("Need a Vulnerability")
        return self.id < other.id

class Dependency:
def __init__(self, package: str, source: Union[str, "DependencyResolver"],
semantic_version: SemanticVersion = SimpleSpec("*")):
Expand Down Expand Up @@ -96,6 +126,7 @@ def __init__(
version: Union[str, Version],
source: Union[str, "DependencyResolver"],
dependencies: Iterable[Dependency] = (),
vulnerabilities: Iterable[Vulnerability] = ()
):
if isinstance(version, str):
version = Version(version)
Expand All @@ -106,6 +137,8 @@ def __init__(
self.source: str = source.name
else:
self.source = source
self.vulnerabilities: FrozenSet[Vulnerability] = \
frozenset(vulnerabilities)

@property
def full_name(self) -> str:
Expand All @@ -115,6 +148,11 @@ def update_dependencies(self, dependencies: FrozenSet[Dependency]):
self.dependencies = self.dependencies.union(dependencies)
return self

def update_vulnerabilities(self, vulnerabilities: FrozenSet[Vulnerability]):
    """Merge ``vulnerabilities`` into this package's set and return ``self``."""
    merged = self.vulnerabilities.union(vulnerabilities)
    self.vulnerabilities = merged
    return self

@property
def resolver(self):
"""
Expand Down Expand Up @@ -165,7 +203,8 @@ def to_obj(self):
"version": str(self.version),
"dependencies": {
f"{dep.source}:{dep.package}": str(dep.semantic_version) for dep in self.dependencies
}
},
"vulnerabilities": [vuln.to_obj() for vuln in self.vulnerabilities]
}
return ret # type: ignore

Expand Down Expand Up @@ -231,8 +270,10 @@ def __init__(
source_repo: SourceRepository,
source: str,
dependencies: Iterable[Dependency] = (),
vulnerabilities: Iterable[Vulnerability] = ()
):
super().__init__(name=name, version=version, dependencies=dependencies, source=source)
super().__init__(name=name, version=version, dependencies=dependencies,
source=source, vulnerabilities=vulnerabilities)
self.source_repo: SourceRepository = source_repo

def __str__(self):
Expand Down Expand Up @@ -445,6 +486,7 @@ def package_to_dict(package: Package):
f"{dep.source}:{dep.package}": str(dep.semantic_version)
for dep in package.dependencies
},
"vulnerabilities": [v.to_compact_str() for v in package.vulnerabilities],
"source": package.source
}
if isinstance(package, SourcePackage):
Expand Down Expand Up @@ -486,7 +528,8 @@ def add_package(pkg: Package) -> str:
if pkg not in package_ids:
pkg_id = f"package{len(package_ids)}"
package_ids[pkg] = pkg_id
dot.node(pkg_id, label=str(pkg), shape="rectangle")
shape = "triangle" if pkg.vulnerabilities else "rectangle"
dot.node(pkg_id, label=str(pkg), shape=shape)
return pkg_id
else:
return package_ids[pkg]
Expand Down
2 changes: 2 additions & 0 deletions it_depends/html.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def graph_to_html(
"color": "red",
"borderWidth": 4,
})
if package.vulnerabilities:
nodes[-1].update({"color": "red"})
if graph.source_packages:
nodes[-1]["level"] = max(graph.shortest_path_from_root(package), 0)
for pkg1, pkg2, *_ in graph.out_edges(package): # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"tqdm>=4.48.0"
],
extras_require={
"dev": ["flake8", "pytest", "twine", "mypy>=0.812"]
"dev": ["flake8", "pytest", "twine", "mypy>=0.812", "types-setuptools", "types-requests"]
},
entry_points={
"console_scripts": [
Expand Down
139 changes: 139 additions & 0 deletions test/test_audit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import threading
from it_depends.dependencies import InMemoryPackageCache, Package, Vulnerability
from it_depends import audit

import logging
import random
import string
import time
from unittest import TestCase
from unittest.mock import Mock, patch


logger = logging.getLogger(__name__)


def _rand_str(n):
"""Returns a random string of length n (upper, lower and digits)"""
return ''.join(random.choice(string.ascii_lowercase +
string.ascii_uppercase + string.digits)
for i in range(n))


def _version_str():
"""Returns a typical version string (x.y.z)"""
return f"{random.randint(0, 30)}.{random.randint(0,5)}." \
f"{random.randint(0, 9)}"


def _random_package():
    """Create a Package whose name, version and source are all random"""
    name = _rand_str(10)
    version = _version_str()
    source = _rand_str(5)
    return Package(name, version, source)


def _random_packages(num_packages):
    """Create an InMemoryPackageCache holding num_packages random Packages"""
    cache = InMemoryPackageCache()
    generated = [_random_package() for _ in range(num_packages)]
    for package in generated:
        cache.add(package)
    return cache


def _random_vulnerability():
    """Build a Vulnerability with random id, aliases and summary"""
    vuln_id = _rand_str(10)
    aliases = []
    # Up to seven candidate aliases; each is kept only when its draw is < 90.
    for _ in range(random.randint(0, 7)):
        if random.randint(0, 100) < 90:
            aliases.append(_rand_str(3))
    summary = _rand_str(random.randint(0, 10))
    return Vulnerability(vuln_id, aliases, summary)


def _random_vulnerabilities(max_count):
    """Build a list of between 0 and max_count random vulnerabilities"""
    count = random.randint(0, max_count)
    return [_random_vulnerability() for _ in range(count)]


class TestAudit(TestCase):
    """Tests for the audit mode; every request to OSV is mocked out."""

    def setUp(self):
        # To be able to repeat a failing test the seed for random is logged
        seed = int(time.time())
        random.seed(seed)
        logger.warning(f"Using seed: {seed}")

    @patch('it_depends.audit.post')
    def test_nopackages_no_requests(self, mock_post):
        """An empty repository is returned as-is and triggers no request"""
        packages = _random_packages(0)
        ret = audit.vulnerabilities(packages)
        self.assertEqual(ret, packages)
        mock_post.assert_not_called()

    @patch('it_depends.audit.post')
    def test_valid_limited_info_response(self, mock_post):
        """Ensures that a single vuln with the minimum amount of info we require works"""
        packages = _random_packages(1)
        mock_post().json.return_value = {"vulns": [{"id": "123"}]}
        ret = audit.vulnerabilities(packages)

        pkg = next(iter(ret))
        vuln = next(iter(pkg.vulnerabilities))  # Assume one vulnerability
        self.assertEqual(vuln.id, "123")
        self.assertEqual(len(vuln.aliases), 0)
        self.assertEqual(vuln.summary, "N/A")

    @patch('it_depends.audit.post')
    def test_no_vulns_can_be_handled(self, mock_post):
        """No vulnerability info can still be handled"""
        packages = _random_packages(1)
        mock_post().json.return_value = {}
        ret = audit.vulnerabilities(packages)
        self.assertTrue(all(len(p.vulnerabilities) == 0 for p in ret))

    @patch('it_depends.audit.post')
    def test_handles_ten_thousand_requests(self, mock_post):
        """Constructs ten thousand random packages and maps random vulnerabilities to the packages.
        Ensures that the vulnerability information received from OSV is reflected in the Packages"""

        # Create 10k random packages (name, version, source)
        packages = _random_packages(10000)

        # For each of the packages map 0 or more vulnerabilities
        package_vuln = {(pkg.name, str(pkg.version)): _random_vulnerabilities(10) for pkg in packages}

        # Mocks the json-request to OSV, returns whatever info is in the package_vuln-map
        def _osv_response(_, json):
            m = Mock()
            key = (json["package"]["name"], json["version"])
            if key in package_vuln:
                m.json.return_value = {"vulns": [v.to_obj() for v in package_vuln[key]]}
            else:
                m.json.return_value = {}
            return m

        mock_post.side_effect = _osv_response

        # Query all packages for vulnerabilities, ensure that each package received vulnerability
        # info as stated in the package_vuln-map created earlier.
        for pkg in audit.vulnerabilities(packages):
            pkgvuln = sorted(pkg.vulnerabilities)
            expectedvuln = sorted(package_vuln[(pkg.name, str(pkg.version))])

            self.assertListEqual(pkgvuln, expectedvuln)

    @patch('it_depends.audit.post')
    def test_exceptions_are_logged_and_isolated(self, mock_post):
        """Ensure that if exceptions happen during vulnerability querying they do not kill execution.
        They shall still be logged."""
        packages = _random_packages(100)
        lock = threading.Lock()
        counter = 0

        def _osv_response(_, json):
            nonlocal counter
            m = Mock()
            m.json.return_value = {}
            with lock:
                counter += 1
                # Read the counter while still holding the lock so that
                # exactly every second request fails, regardless of how
                # the worker threads interleave.
                current = counter
            if current % 2 == 0:
                raise Exception("Ouch.")
            return m
        mock_post.side_effect = _osv_response

        # The failures must not propagate, but each one has to be logged.
        with self.assertLogs(audit.logger, level="ERROR") as captured:
            result = audit.vulnerabilities(packages)

        self.assertEqual(len(result), 100)
        self.assertEqual(len(captured.records), 50)
Loading