From 59aa351ca815c6112831a3093645a523b57563c5 Mon Sep 17 00:00:00 2001 From: "marco.santamaria" Date: Wed, 13 May 2015 11:26:35 +0200 Subject: [PATCH] Added support for json post data in filter_post_data_parameters. --- tests/integration/test_filter.py | 13 +++++++++++++ tests/unit/test_filters.py | 27 +++++++++++++++++++++++++++ vcr/filters.py | 27 ++++++++++++++++++--------- 3 files changed, 58 insertions(+), 9 deletions(-) diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index 2c5bae34..0a5232d0 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -4,6 +4,7 @@ from six.moves.urllib.parse import urlencode from six.moves.urllib.error import HTTPError import vcr +import json def _request_with_auth(url, username, password): @@ -66,6 +67,18 @@ def test_filter_post_data(tmpdir): assert b'id=secret' not in cass.requests[0].body +def test_filter_json_post_data(tmpdir): + data = json.dumps({'id': 'secret', 'foo': 'bar'}).encode('utf-8') + request = Request('http://httpbin.org/post', data=data) + request.add_header('Content-Type', 'application/json') + + cass_file = str(tmpdir.join('filter_jpd.yaml')) + with vcr.use_cassette(cass_file, filter_post_data_parameters=['id']): + urlopen(request) + with vcr.use_cassette(cass_file, filter_post_data_parameters=['id']) as cass: + assert b'"id": "secret"' not in cass.requests[0].body + + def test_filter_callback(tmpdir): url = 'http://httpbin.org/get' cass_file = str(tmpdir.join('basic_auth_filter.yaml')) diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index 546579ad..2bab38c0 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -4,6 +4,7 @@ remove_post_data_parameters ) from vcr.request import Request +import json def test_remove_headers(): @@ -67,3 +68,29 @@ def test_remove_nonexistent_post_data_parameters(): request = Request('POST', 'http://google.com', body, {}) remove_post_data_parameters(request, ['id']) assert request.body == b'' + + +def test_remove_json_post_data_parameters(): + body = b'{"id": "secret", "foo": "bar", "baz": "qux"}' + request = Request('POST', 'http://google.com', body, {}) + request.add_header('Content-Type', 'application/json') + remove_post_data_parameters(request, ['id']) + request_body_json = json.loads(request.body.decode('utf-8')) + expected_json = json.loads(b'{"foo": "bar", "baz": "qux"}'.decode('utf-8')) + assert request_body_json == expected_json + + +def test_remove_all_json_post_data_parameters(): + body = b'{"id": "secret", "foo": "bar"}' + request = Request('POST', 'http://google.com', body, {}) + request.add_header('Content-Type', 'application/json') + remove_post_data_parameters(request, ['id', 'foo']) + assert request.body == b'{}' + + +def test_remove_nonexistent_json_post_data_parameters(): + body = b'{}' + request = Request('POST', 'http://google.com', body, {}) + request.add_header('Content-Type', 'application/json') + remove_post_data_parameters(request, ['id']) + assert request.body == b'{}' diff --git a/vcr/filters.py b/vcr/filters.py index 936bf506..72178131 100644 --- a/vcr/filters.py +++ b/vcr/filters.py @@ -5,6 +5,7 @@ except ImportError: from backport_collections import OrderedDict import copy +import json def remove_headers(request, headers_to_remove): @@ -31,13 +32,21 @@ def remove_query_parameters(request, query_parameters_to_remove): def remove_post_data_parameters(request, post_data_parameters_to_remove): if request.method == 'POST' and not isinstance(request.body, BytesIO): - post_data = OrderedDict() - for k, sep, v in [p.partition(b'=') for p in request.body.split(b'&')]: - if k in post_data: - post_data[k].append(v) - elif len(k) > 0 and k.decode('utf-8') not in post_data_parameters_to_remove: - post_data[k] = [v] - request.body = b'&'.join( - b'='.join([k, v]) - for k, vals in post_data.items() for v in vals) + if ('Content-Type' in request.headers and + request.headers['Content-Type'] == 'application/json'): + json_data = json.loads(request.body.decode('utf-8')) + for k in list(json_data.keys()): + if k in post_data_parameters_to_remove: + del json_data[k] + request.body = json.dumps(json_data).encode('utf-8') + else: + post_data = OrderedDict() + for k, sep, v in [p.partition(b'=') for p in request.body.split(b'&')]: + if k in post_data: + post_data[k].append(v) + elif len(k) > 0 and k.decode('utf-8') not in post_data_parameters_to_remove: + post_data[k] = [v] + request.body = b'&'.join( + b'='.join([k, v]) + for k, vals in post_data.items() for v in vals) return request