diff --git a/examples/hunting_notifications_to_network_infrastructure.py b/examples/hunting_notifications_to_network_infrastructure.py index f40c4b5..bc43d02 100644 --- a/examples/hunting_notifications_to_network_infrastructure.py +++ b/examples/hunting_notifications_to_network_infrastructure.py @@ -98,21 +98,27 @@ async def get_network_infrastructure(self): contacted_domains = relationships["contacted_domains"]["data"] contacted_ips = relationships["contacted_ips"]["data"] contacted_urls = relationships["contacted_urls"]["data"] - await self.queue.put({ - "contacted_addresses": contacted_domains, - "type": "domains", - "file": file_hash, - }) - await self.queue.put({ - "contacted_addresses": contacted_ips, - "type": "ips", - "file": file_hash, - }) - await self.queue.put({ - "contacted_addresses": contacted_urls, - "type": "urls", - "file": file_hash, - }) + await self.queue.put( + { + "contacted_addresses": contacted_domains, + "type": "domains", + "file": file_hash, + } + ) + await self.queue.put( + { + "contacted_addresses": contacted_ips, + "type": "ips", + "file": file_hash, + } + ) + await self.queue.put( + { + "contacted_addresses": contacted_urls, + "type": "urls", + "file": file_hash, + } + ) self.networking_infrastructure[file_hash]["domains"] = contacted_domains self.networking_infrastructure[file_hash]["ips"] = contacted_ips self.networking_infrastructure[file_hash]["urls"] = contacted_urls diff --git a/examples/intelligence_search_to_network_infrastructure.py b/examples/intelligence_search_to_network_infrastructure.py index 36a4156..bd8d6ef 100644 --- a/examples/intelligence_search_to_network_infrastructure.py +++ b/examples/intelligence_search_to_network_infrastructure.py @@ -94,21 +94,27 @@ async def get_network(self): contacted_urls = relationships["contacted_urls"]["data"] contacted_ips = relationships["contacted_ips"]["data"] - await self.queue.put({ - "contacted_addresses": contacted_domains, - "type": "domains", - "file": checksum, - }) - await self.queue.put({ - "contacted_addresses": contacted_ips, - "type": "ips", - "file": checksum, - }) - await self.queue.put({ - "contacted_addresses": contacted_urls, - "type": "urls", - "file": checksum, - }) + await self.queue.put( + { + "contacted_addresses": contacted_domains, + "type": "domains", + "file": checksum, + } + ) + await self.queue.put( + { + "contacted_addresses": contacted_ips, + "type": "ips", + "file": checksum, + } + ) + await self.queue.put( + { + "contacted_addresses": contacted_urls, + "type": "urls", + "file": checksum, + } + ) self.networking_infrastructure[checksum]["domains"] = contacted_domains self.networking_infrastructure[checksum]["ips"] = contacted_ips diff --git a/examples/livehunt_network_watch.py b/examples/livehunt_network_watch.py index 610ab14..cf3df34 100644 --- a/examples/livehunt_network_watch.py +++ b/examples/livehunt_network_watch.py @@ -40,8 +40,8 @@ RULESET_LINK = "https://www.virustotal.com/yara-editor/livehunt/" EMPTY_DOMAIN_LIST_MSG = ( - "* Empty domain list, use --add-domain domain.tld or bulk operations to" - " register them" + "* Empty domain list, use --add-domain domain.tld or bulk operations to" + " register them" ) @@ -247,8 +247,9 @@ async def main(): return rulesets = await get_rulesets() - if (not rulesets and - not (args.add_domain or args.bulk_append or args.bulk_replace)): + if not rulesets and not ( + args.add_domain or args.bulk_append or args.bulk_replace + ): print(EMPTY_DOMAIN_LIST_MSG) sys.exit(1) diff --git a/examples/private_scan.py b/examples/private_scan.py index 1b350b1..b16954f 100644 --- a/examples/private_scan.py +++ b/examples/private_scan.py @@ -13,77 +13,68 @@ console = Console() + async def scan_file_private( - api_key: str, - file_path: Path, - wait: bool = False + api_key: str, file_path: Path, wait: bool = False ) -> None: - """ - Scan a file privately on VirusTotal. - - Args: - api_key: VirusTotal API key - file_path: Path to file to scan - wait: Wait for scan completion - """ - async with vt.Client(api_key) as client: - try: - with Progress() as progress: - task = progress.add_task( - "Scanning file...", - total=None if wait else 1 - ) - - analysis = await client.scan_file_private_async( - str(file_path), - wait_for_completion=wait - ) - - progress.update(task, advance=1) - - console.print("\n[green]Scan submitted successfully[/green]") - console.print(f"Analysis ID: {analysis.id}") - - if wait: - console.print(f"\nScan Status: {analysis.status}") - if hasattr(analysis, 'stats'): - console.print("Detection Stats:") - for k, v in analysis.stats.items(): - console.print(f" {k}: {v}") - - except vt.error.APIError as e: - console.print(f"[red]API Error: {e}[/red]") - except Exception as e: - console.print(f"[red]Error: {e}[/red]") + """ + Scan a file privately on VirusTotal. + + Args: + api_key: VirusTotal API key + file_path: Path to file to scan + wait: Wait for scan completion + """ + async with vt.Client(api_key) as client: + try: + with Progress() as progress: + task = progress.add_task("Scanning file...", total=None if wait else 1) + + analysis = await client.scan_file_private_async( + str(file_path), wait_for_completion=wait + ) + + progress.update(task, advance=1) + + console.print("\n[green]Scan submitted successfully[/green]") + console.print(f"Analysis ID: {analysis.id}") + + if wait: + console.print(f"\nScan Status: {analysis.status}") + if hasattr(analysis, "stats"): + console.print("Detection Stats:") + for k, v in analysis.stats.items(): + console.print(f" {k}: {v}") + + except vt.error.APIError as e: + console.print(f"[red]API Error: {e}[/red]") + except Exception as e: + console.print(f"[red]Error: {e}[/red]") + def main(): - parser = argparse.ArgumentParser( - description="Scan file privately using VirusTotal API" - ) - parser.add_argument("--apikey", help="VirusTotal API key") - parser.add_argument("--file_path", help="Path to file to scan") - parser.add_argument( - "--wait", - action="store_true", - help="Wait for scan completion" - ) - - args = parser.parse_args() - file_path = Path(args.file_path) - - if not file_path.exists(): - console.print(f"[red]Error: File {file_path} not found[/red]") - sys.exit(1) - - if not file_path.is_file(): - console.print(f"[red]Error: {file_path} is not a file[/red]") - sys.exit(1) - - asyncio.run(scan_file_private( - args.apikey, - file_path, - args.wait - )) + parser = argparse.ArgumentParser( + description="Scan file privately using VirusTotal API" + ) + parser.add_argument("--apikey", help="VirusTotal API key") + parser.add_argument("--file_path", help="Path to file to scan") + parser.add_argument( + "--wait", action="store_true", help="Wait for scan completion" + ) + + args = parser.parse_args() + file_path = Path(args.file_path) + + if not file_path.exists(): + console.print(f"[red]Error: File {file_path} not found[/red]") + sys.exit(1) + + if not file_path.is_file(): + console.print(f"[red]Error: {file_path} is not a file[/red]") + sys.exit(1) + + asyncio.run(scan_file_private(args.apikey, file_path, args.wait)) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/retrohunt_to_network_infrastructure.py b/examples/retrohunt_to_network_infrastructure.py index 85dd846..26867f4 100644 --- a/examples/retrohunt_to_network_infrastructure.py +++ b/examples/retrohunt_to_network_infrastructure.py @@ -87,19 +87,23 @@ async def get_network_infrastructure(self, file_obj): contacted_domains = relationships["contacted_domains"]["data"] contacted_ips = relationships["contacted_ips"]["data"] contacted_urls = relationships["contacted_urls"]["data"] - await self.networking_queue.put({ - "contacted_addresses": contacted_domains, - "type": "domains", - "file": file_hash, - }) + await self.networking_queue.put( + { + "contacted_addresses": contacted_domains, + "type": "domains", + "file": file_hash, + } + ) await self.networking_queue.put( {"contacted_addresses": contacted_ips, "type": "ips", "file": file_hash} ) - await self.networking_queue.put({ - "contacted_addresses": contacted_urls, - "type": "urls", - "file": file_hash, - }) + await self.networking_queue.put( + { + "contacted_addresses": contacted_urls, + "type": "urls", + "file": file_hash, + } + ) self.networking_infrastructure[file_hash]["domains"] = contacted_domains self.networking_infrastructure[file_hash]["ips"] = contacted_ips self.networking_infrastructure[file_hash]["urls"] = contacted_urls diff --git a/tests/test_client.py b/tests/test_client.py index 057c3f6..29aacad 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -36,15 +36,19 @@ def new_client(httpserver, unused_apikey=""): def test_object_from_dict(): - obj = Object.from_dict({ - "type": "dummy_type", - "id": "dummy_id", - "attributes": { - "attr1": "foo", - "attr2": 1, - }, - "relationships": {"foos": {"data": [{"type": "foo", "id": "foo_id"}]}}, - }) + obj = Object.from_dict( + { + "type": "dummy_type", + "id": "dummy_id", + "attributes": { + "attr1": "foo", + "attr2": 1, + }, + "relationships": { + "foos": {"data": [{"type": "foo", "id": "foo_id"}]} + }, + } + ) assert obj.id == "dummy_id" assert obj.type == "dummy_type" @@ -83,16 +87,18 @@ def test_object_pickle(): def test_object_to_dict(): - obj = Object.from_dict({ - "type": "dummy_type", - "id": "dummy_id", - "attributes": { - "attr1": "foo", - "attr2": 1, - "attr3": {"subattr1": "bar"}, - "attr4": {"subattr1": "baz"}, - }, - }) + obj = Object.from_dict( + { + "type": "dummy_type", + "id": "dummy_id", + "attributes": { + "attr1": "foo", + "attr2": 1, + "attr3": {"subattr1": "bar"}, + "attr4": {"subattr1": "baz"}, + }, + } + ) obj.set_data("data_key", {"some": "value"}) @@ -195,7 +201,7 @@ def test_patch_object(httpserver): "attributes": { "foo": 2, }, - "context_attributes": {"a": "b"} + "context_attributes": {"a": "b"}, } } ) @@ -233,7 +239,9 @@ def test_post_object(httpserver): def test_delete(httpserver): httpserver.expect_request( - "/api/v3/foo", method="DELETE", headers={"X-Apikey": "dummy_api_key"}, + "/api/v3/foo", + method="DELETE", + headers={"X-Apikey": "dummy_api_key"}, json={"hello": "world"}, ).respond_with_json({"data": "dummy_data"}) @@ -250,11 +258,13 @@ def test_iterator(httpserver): headers={"X-Apikey": "dummy_api_key"}, ).respond_with_json( { - "data": [{ - "id": "dummy_id_1", - "type": "dummy_type", - "attributes": {"order": 0}, - }] + "data": [ + { + "id": "dummy_id_1", + "type": "dummy_type", + "attributes": {"order": 0}, + } + ] } ) @@ -490,16 +500,18 @@ def test_wsgi_app(httpserver, monkeypatch): app = wsgi_app.app app.config.update({"TESTING": True}) client = app.test_client() - expected_response = {"data": { - "id": "google.com", - "type": "domain", - "attributes": {"foo": "foo"}, - }} - + expected_response = { + "data": { + "id": "google.com", + "type": "domain", + "attributes": {"foo": "foo"}, + } + } httpserver.expect_request( - "/api/v3/domains/google.com", method="GET", - headers={"X-Apikey": "dummy_api_key"} + "/api/v3/domains/google.com", + method="GET", + headers={"X-Apikey": "dummy_api_key"}, ).respond_with_json(expected_response) monkeypatch.setattr( "tests.wsgi_app.vt.Client", functools.partial(new_client, httpserver) @@ -508,79 +520,72 @@ def test_wsgi_app(httpserver, monkeypatch): assert response.status_code == 200 assert response.json == expected_response + @pytest.fixture def private_scan_mocks(httpserver): - """Fixture for mocking private scan API calls.""" - upload_url = f"http://{httpserver.host}:{httpserver.port}/upload" - - # Mock private upload URL request - httpserver.expect_request( - "/api/v3/private/files/upload_url", - method="GET" - ).respond_with_json({ - "data": upload_url - }) - - # Mock file upload response - httpserver.expect_request( - "/upload", - method="POST" - ).respond_with_json({ - "data": { - "id": "dummy_scan_id", - "type": "private_analysis", - "links": { - "self": "dummy_link" - }, - "attributes": { - "status": "queued", - } - } - }) - - # Add mock for analysis status endpoint - httpserver.expect_request( - "/api/v3/analyses/dummy_scan_id", - method="GET" - ).respond_with_json({ - "data": { - "id": "dummy_scan_id", - "type": "private_analysis", - "links": { - "self": "dummy_link" - }, - "attributes": { - "status": "completed", - "stats": { - "malicious": 0, - "suspicious": 0 - } - } - } - }) - - return upload_url + """Fixture for mocking private scan API calls.""" + upload_url = f"http://{httpserver.host}:{httpserver.port}/upload" + + # Mock private upload URL request + httpserver.expect_request( + "/api/v3/private/files/upload_url", method="GET" + ).respond_with_json({"data": upload_url}) + + # Mock file upload response + httpserver.expect_request("/upload", method="POST").respond_with_json( + { + "data": { + "id": "dummy_scan_id", + "type": "private_analysis", + "links": {"self": "dummy_link"}, + "attributes": { + "status": "queued", + }, + } + } + ) + + # Add mock for analysis status endpoint + httpserver.expect_request( + "/api/v3/analyses/dummy_scan_id", method="GET" + ).respond_with_json( + { + "data": { + "id": "dummy_scan_id", + "type": "private_analysis", + "links": {"self": "dummy_link"}, + "attributes": { + "status": "completed", + "stats": {"malicious": 0, "suspicious": 0}, + }, + } + } + ) + + return upload_url + def verify_analysis(analysis, status="queued"): - """Helper to verify analysis response.""" - assert analysis.id == "dummy_scan_id" - assert analysis.type == "private_analysis" - assert getattr(analysis, "status") == status + """Helper to verify analysis response.""" + assert analysis.id == "dummy_scan_id" + assert analysis.type == "private_analysis" + assert getattr(analysis, "status") == status + def test_scan_file_private(httpserver, private_scan_mocks): - """Test synchronous private file scanning.""" - with new_client(httpserver) as client: - with io.StringIO("test file content") as f: - analysis = client.scan_file_private(f) - verify_analysis(analysis) + """Test synchronous private file scanning.""" + with new_client(httpserver) as client: + with io.StringIO("test file content") as f: + analysis = client.scan_file_private(f) + verify_analysis(analysis) + @pytest.mark.asyncio async def test_scan_file_private_async(httpserver, private_scan_mocks): - """Test asynchronous private file scanning.""" - async with new_client(httpserver) as client: - with io.StringIO("test file content") as f: - analysis = await client.scan_file_private_async( - f, - wait_for_completion=True - ) - verify_analysis(analysis, status="completed") \ No newline at end of file + """Test asynchronous private file scanning.""" + async with new_client(httpserver) as client: + with io.StringIO("test file content") as f: + analysis = await client.scan_file_private_async( + f, wait_for_completion=True + ) + verify_analysis(analysis, status="completed") diff --git a/tests/test_feed.py b/tests/test_feed.py index 037a27a..3e2fcf7 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -179,7 +179,7 @@ def test_tolerance(httpserver, tolerance): ], ) @pytest.mark.usefixtures("feed_response") -def test_cursor(httpserver, test_iters, expected_cursor): +def test_cursor(httpserver, test_iters, expected_cursor): """Tests feed's cursor.""" with new_client(httpserver) as client: feed = client.feed(FeedType.FILES, cursor="200102030405") diff --git a/tests/test_iterator.py b/tests/test_iterator.py index 6f06178..225593b 100644 --- a/tests/test_iterator.py +++ b/tests/test_iterator.py @@ -32,26 +32,28 @@ def fixture_iterator_response(httpserver): "/api/v3/dummy_collection/foo", method="GET", headers={"X-Apikey": "dummy_api_key"}, - ).respond_with_json({ - "data": [ - { - "id": "dummy_id_1", - "type": "dummy_type", - "attributes": {"order": 0}, - }, - { - "id": "dummy_id_2", - "type": "dummy_type", - "attributes": {"order": 0}, - }, - { - "id": "dummy_id_3", - "type": "dummy_type", - "attributes": {"order": 0}, - }, - ], - "meta": {"cursor": "3", "total_hits": 200}, - }) + ).respond_with_json( + { + "data": [ + { + "id": "dummy_id_1", + "type": "dummy_type", + "attributes": {"order": 0}, + }, + { + "id": "dummy_id_2", + "type": "dummy_type", + "attributes": {"order": 0}, + }, + { + "id": "dummy_id_3", + "type": "dummy_type", + "attributes": {"order": 0}, + }, + ], + "meta": {"cursor": "3", "total_hits": 200}, + } + ) httpserver.expect_ordered_request( "/api/v3/dummy_collection/foo", method="GET", diff --git a/vt/client.py b/vt/client.py index d43e9ef..044ea8c 100644 --- a/vt/client.py +++ b/vt/client.py @@ -218,7 +218,7 @@ def __init__( proxy: typing.Optional[str] = None, headers: typing.Optional[typing.Dict] = None, verify_ssl: bool = True, - connector: aiohttp.BaseConnector = None + connector: aiohttp.BaseConnector = None, ): """Initialize the client with the provided API key.""" @@ -252,7 +252,7 @@ def __init__( ssl=self._verify_ssl, loop=event_loop ) - def _full_url(self, path:str, *args: typing.Any) -> str: + def _full_url(self, path: str, *args: typing.Any) -> str: try: path = path.format(*args) except IndexError as exc: @@ -281,7 +281,7 @@ def _get_session(self) -> aiohttp.ClientSession: headers=headers, trust_env=self._trust_env, timeout=aiohttp.ClientTimeout(total=self._timeout), - json_serialize=functools.partial(json.dumps, cls=UserDictJsonEncoder) + json_serialize=functools.partial(json.dumps, cls=UserDictJsonEncoder), ) return self._session @@ -339,11 +339,11 @@ def close(self) -> None: return make_sync(self.close_async()) def delete( - self, - path: str, - *path_args: typing.Any, - data: typing.Optional[typing.Union[str, bytes]] = None, - json_data: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + data: typing.Optional[typing.Union[str, bytes]] = None, + json_data: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Sends a DELETE request to a given API endpoint. @@ -366,7 +366,7 @@ async def delete_async( path: str, *path_args: typing.Any, data: typing.Optional[typing.Union[str, bytes]] = None, - json_data: typing.Optional[typing.Dict] = None + json_data: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Like :func:`delete` but returns a coroutine.""" return ClientResponse( @@ -374,7 +374,7 @@ async def delete_async( self._full_url(path, *path_args), data=data, json=json_data, - proxy=self._proxy + proxy=self._proxy, ) ) @@ -392,9 +392,7 @@ def download_file(self, file_hash: str, file: typing.BinaryIO) -> None: return make_sync(self.download_file_async(file_hash, file)) async def __download_async( - self, - endpoint: str, - file: typing.BinaryIO + self, endpoint: str, file: typing.BinaryIO ) -> None: """Downloads a file and writes it to file. @@ -412,19 +410,17 @@ async def __download_async( file.write(chunk) async def download_file_async( - self, - file_hash : str, - file: typing.BinaryIO + self, file_hash: str, file: typing.BinaryIO ) -> None: """Like :func:`download_file` but returns a coroutine.""" await self.__download_async(f"/files/{file_hash}/download", file) def download_zip_files( - self, - hashes: typing.List[str], - zipfile: typing.BinaryIO, - password: typing.Optional[str] = None, - sleep_time: int = 20 + self, + hashes: typing.List[str], + zipfile: typing.BinaryIO, + password: typing.Optional[str] = None, + sleep_time: int = 20, ) -> None: """Creates a bundle zip bundle containing one or multiple files. @@ -442,11 +438,11 @@ def download_zip_files( ) async def download_zip_files_async( - self, - hashes: typing.List[str], - zipfile: typing.BinaryIO, - password: typing.Optional[str] = None, - sleep_time: int = 20 + self, + hashes: typing.List[str], + zipfile: typing.BinaryIO, + password: typing.Optional[str] = None, + sleep_time: int = 20, ) -> None: data = {"hashes": hashes} if password: @@ -485,9 +481,7 @@ async def download_zip_files_async( ) def feed( - self, - feed_type: FeedType, - cursor: typing.Optional[str] = None + self, feed_type: FeedType, cursor: typing.Optional[str] = None ) -> Feed: """Returns an iterator for a VirusTotal feed. @@ -506,10 +500,10 @@ def feed( return Feed(self, feed_type, cursor=cursor) def get( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Sends a GET request to a given API endpoint. @@ -528,10 +522,10 @@ def get( return make_sync(self.get_async(path, *path_args, params=params)) async def get_async( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Like :func:`get` but returns a coroutine.""" return ClientResponse( @@ -541,10 +535,10 @@ async def get_async( ) def get_data( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> typing.Any: """Sends a GET request to a given API endpoint and returns response's data. @@ -574,15 +568,14 @@ async def get_data_async( self, path: str, *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None - ) -> typing.Any: + params: typing.Optional[typing.Dict] = None, + ) -> typing.Any: """Like :func:`get_data` but returns a coroutine.""" json_response = await self.get_json_async(path, *path_args, params=params) return self._extract_data_from_json(json_response) async def get_error_async( - self, - response: ClientResponse + self, response: ClientResponse ) -> typing.Optional[APIError]: """Given a :class:`ClientResponse` returns a :class:`APIError` @@ -605,10 +598,10 @@ async def get_error_async( return APIError("ServerError", await response.text_async()) def get_json( - self, - path: str , - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> typing.Dict: """Sends a GET request to a given API endpoint and parses the response. @@ -627,20 +620,20 @@ def get_json( return make_sync(self.get_json_async(path, *path_args, params=params)) async def get_json_async( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> typing.Dict: """Like :func:`get_json` but returns a coroutine.""" response = await self.get_async(path, *path_args, params=params) return await self._response_to_json(response) def get_object( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> Object: """Sends a GET request to a given API endpoint and returns an object. @@ -661,21 +654,21 @@ def get_object( return make_sync(self.get_object_async(path, *path_args, params=params)) async def get_object_async( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, ) -> Object: """Like :func:`get_object` but returns a coroutine.""" response = await self.get_async(path, *path_args, params=params) return await self._response_to_object(response) def patch( - self, - path: str, - *path_args: typing.Any, - data: typing.Optional[typing.Union[str, bytes]] = None, - json_data: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + data: typing.Optional[typing.Union[str, bytes]] = None, + json_data: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Sends a PATCH request to a given API endpoint. @@ -696,11 +689,11 @@ def patch( return make_sync(self.patch_async(path, *path_args, data, json_data)) async def patch_async( - self, - path: str, - *path_args: typing.Any, - data: typing.Optional[typing.Union[str, bytes]] = None, - json_data: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + data: typing.Optional[typing.Union[str, bytes]] = None, + json_data: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Like :func:`patch` but returns a coroutine.""" return ClientResponse( @@ -713,10 +706,7 @@ async def patch_async( ) def patch_object( - self, - path: str, - *path_args: typing.Any, - obj: Object + self, path: str, *path_args: typing.Any, obj: Object ) -> Object: """Sends a PATCH request for modifying an object. @@ -735,10 +725,7 @@ def patch_object( return make_sync(self.patch_object_async(path, *path_args, obj=obj)) async def patch_object_async( - self, - path: str, - *path_args: typing.Any, - obj: Object + self, path: str, *path_args: typing.Any, obj: Object ) -> Object: """Like :func:`patch_object` but returns a coroutine.""" data = {"data": obj.to_dict(modified_attributes_only=True)} @@ -747,11 +734,11 @@ async def patch_object_async( return await self._response_to_object(response) def post( - self, - path: str, - *path_args: typing.Any, - data: typing.Optional[typing.Union[str, bytes]] = None, - json_data: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + data: typing.Optional[typing.Union[str, bytes]] = None, + json_data: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Sends a POST request to a given API endpoint. @@ -774,11 +761,11 @@ def post( ) async def post_async( - self, - path: str, - *path_args: typing.Any, - data: typing.Optional[typing.Union[str, bytes]] = None, - json_data: typing.Optional[typing.Dict] = None + self, + path: str, + *path_args: typing.Any, + data: typing.Optional[typing.Union[str, bytes]] = None, + json_data: typing.Optional[typing.Dict] = None, ) -> ClientResponse: """Like :func:`post` but returns a coroutine.""" return ClientResponse( @@ -791,10 +778,7 @@ async def post_async( ) def post_object( - self, - path: str, - *path_args: typing.Any, - obj: Object + self, path: str, *path_args: typing.Any, obj: Object ) -> Object: """Sends a POST request for creating an object. @@ -813,10 +797,7 @@ def post_object( return make_sync(self.post_object_async(path, *path_args, obj=obj)) async def post_object_async( - self, - path: str, - *path_args: typing.Any, - obj: Object + self, path: str, *path_args: typing.Any, obj: Object ) -> Object: """Like :func:`post_object` but returns a coroutine.""" data = {"data": obj.to_dict()} @@ -825,13 +806,13 @@ async def post_object_async( return await self._response_to_object(response) def iterator( - self, - path: str, - *path_args: typing.Any, - params: typing.Optional[typing.Dict] = None, - cursor: typing.Optional[str] = None, - limit: typing.Optional[int] = None, - batch_size: int = 0 + self, + path: str, + *path_args: typing.Any, + params: typing.Optional[typing.Dict] = None, + cursor: typing.Optional[str] = None, + limit: typing.Optional[int] = None, + batch_size: int = 0, ) -> Iterator: """Returns an iterator for the collection specified by the given path. @@ -868,9 +849,7 @@ def iterator( ) def scan_file( - self, - file: typing.BinaryIO, - wait_for_completion: bool = False + self, file: typing.BinaryIO, wait_for_completion: bool = False ) -> Object: """Scans a file. @@ -886,9 +865,7 @@ def scan_file( ) async def scan_file_async( - self, - file: typing.BinaryIO, - wait_for_completion: bool = False + self, file: typing.BinaryIO, wait_for_completion: bool = False ) -> Object: """Like :func:`scan_file` but returns a coroutine.""" @@ -944,9 +921,7 @@ def scan_url(self, url: str, wait_for_completion: bool = False) -> Object: ) async def scan_url_async( - self, - url: str, - wait_for_completion: bool = False + self, url: str, wait_for_completion: bool = False ) -> Object: """Like :func:`scan_url` but returns a coroutine.""" form_data = aiohttp.FormData() @@ -977,51 +952,48 @@ async def wait_for_analysis_completion(self, analysis: Object) -> Object: return await self._wait_for_analysis_completion(analysis) def scan_file_private( - self, + self, file: typing.Union[typing.BinaryIO, str], - wait_for_completion: bool = False + wait_for_completion: bool = False, ) -> Object: - """Scan file privately. - - Args: - file: File to scan (path string or file object) - wait_for_completion: Wait for completion - - Returns: - Object: Analysis object with scan results - """ - return make_sync( - self.scan_file_private_async(file, wait_for_completion) - ) + """Scan file privately. + + Args: + file: File to scan (path string or file object) + wait_for_completion: Wait for completion + + Returns: + Object: Analysis object with scan results + """ + return make_sync(self.scan_file_private_async(file, wait_for_completion)) async def scan_file_private_async( self, file: typing.Union[typing.BinaryIO, str], - wait_for_completion: bool = False + wait_for_completion: bool = False, ) -> Object: - """Async version of scan_file_private""" - - # Handle string path - if isinstance(file, str): - async with aiofiles.open(file, 'rb') as f: - file_content = io.BytesIO(await f.read()) - file_content.name = os.path.basename(file) - return await self.scan_file_private_async( - file_content, - wait_for_completion=wait_for_completion - ) + """Async version of scan_file_private""" + + # Handle string path + if isinstance(file, str): + async with aiofiles.open(file, "rb") as f: + file_content = io.BytesIO(await f.read()) + file_content.name = os.path.basename(file) + return await self.scan_file_private_async( + file_content, wait_for_completion=wait_for_completion + ) - # Create form data for private scan - form = aiohttp.FormData() - form.add_field('file', file) + # Create form data for private scan + form = aiohttp.FormData() + form.add_field("file", file) - # Get private upload URL and submit - upload_url = await self.get_data_async("/private/files/upload_url") - response = await self.post_async(upload_url, data=form) + # Get private upload URL and submit + upload_url = await self.get_data_async("/private/files/upload_url") + response = await self.post_async(upload_url, data=form) - analysis = await self._response_to_object(response) + analysis = await self._response_to_object(response) - if wait_for_completion: - analysis = await self._wait_for_analysis_completion(analysis) + if wait_for_completion: + analysis = await self._wait_for_analysis_completion(analysis) - return analysis \ No newline at end of file + return analysis diff --git a/vt/feed.py b/vt/feed.py index 53f31c5..230230b 100644 --- a/vt/feed.py +++ b/vt/feed.py @@ -61,10 +61,10 @@ class Feed: """ def __init__( - self, - client: "Client", - feed_type: FeedType, - cursor: typing.Optional[str] = None + self, + client: "Client", + feed_type: FeedType, + cursor: typing.Optional[str] = None, ): """Initializes a Feed object. diff --git a/vt/iterator.py b/vt/iterator.py index 7c2b576..c6c6adf 100644 --- a/vt/iterator.py +++ b/vt/iterator.py @@ -67,12 +67,13 @@ class Iterator: # pylint: disable=line-too-long def __init__( - self, - client: "Client", - path: str, params=None, - cursor: typing.Optional[str] = None, - limit: typing.Optional[int] = None, - batch_size: int = 0 + self, + client: "Client", + path: str, + params=None, + cursor: typing.Optional[str] = None, + limit: typing.Optional[int] = None, + batch_size: int = 0, ): """Initializes an iterator. @@ -113,7 +114,9 @@ def _build_params(self) -> typing.Dict: params["limit"] = self._batch_size return params - def _parse_response(self, json_resp: typing.Dict, batch_cursor: int) -> typing.Tuple[typing.List[typing.Dict], typing.Dict]: + def _parse_response( + self, json_resp: typing.Dict, batch_cursor: int + ) -> typing.Tuple[typing.List[typing.Dict], typing.Dict]: if not isinstance(json_resp.get("data"), list): raise ValueError(f"{self._path} is not a collection") meta = json_resp.get("meta", {}) @@ -121,7 +124,9 @@ def _parse_response(self, json_resp: typing.Dict, batch_cursor: int) -> typing.T return items, meta - async def _get_batch_async(self, batch_cursor: int = 0) -> typing.Tuple[typing.List[typing.Dict], typing.Dict]: + async def _get_batch_async( + self, batch_cursor: int = 0 + ) -> typing.Tuple[typing.List[typing.Dict], typing.Dict]: json_resp = await self._client.get_json_async( self._path, params=self._build_params() ) diff --git a/vt/object.py b/vt/object.py index 58153e1..c5a5338 100644 --- a/vt/object.py +++ b/vt/object.py @@ -31,9 +31,7 @@ class WhistleBlowerDict(collections.UserDict): """ def __init__( - self, - initial_dict: typing.Dict, - on_change_callback: typing.Callable + self, initial_dict: typing.Dict, on_change_callback: typing.Callable ): self._on_change_callback = on_change_callback for k, v in initial_dict.items(): @@ -128,10 +126,10 @@ def from_dict(cls, obj_dict: typing.Dict): return obj def __init__( - self, - obj_type: str, - obj_id: typing.Optional[str] = None, - obj_attributes: typing.Optional[typing.Dict] = None + self, + obj_type: str, + obj_id: typing.Optional[str] = None, + obj_attributes: typing.Optional[typing.Dict] = None, ): """Initializes a VirusTotal API object.""" @@ -210,9 +208,7 @@ def error(self) -> typing.Optional[typing.Dict]: return self._error def get( - self, - attr_name: str, - default: typing.Optional[typing.Any] = None + self, attr_name: str, default: typing.Optional[typing.Any] = None ) -> typing.Any: """Returns an attribute by name. diff --git a/vt/utils.py b/vt/utils.py index 4aef8a6..fda40f1 100644 --- a/vt/utils.py +++ b/vt/utils.py @@ -16,6 +16,7 @@ import asyncio import typing + def make_sync(future: typing.Coroutine): """Utility function that waits for an async call, making it sync.""" try: