From 20bb733258fbfc1041fe271f04fa81e0f9ee81f7 Mon Sep 17 00:00:00 2001 From: Antti Haapala Date: Tue, 29 Aug 2017 20:19:19 +0300 Subject: [PATCH 1/3] add ``require_csrf`` argument to ``add_jsonrpc_endpoint`` --- CONTRIBUTORS.txt | 2 + pyramid_rpc/jsonrpc.py | 22 +++++- pyramid_rpc/tests/test_jsonrpc.py | 121 +++++++++++++++++++++++++++++- 3 files changed, 140 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 082648a..ce6564d 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -108,3 +108,5 @@ Contributors - Donald Stufft, 8/11/2015 - Ben Holzman, 11/17/2015 + +- Antti Haapala, 8/29/2017 diff --git a/pyramid_rpc/jsonrpc.py b/pyramid_rpc/jsonrpc.py index e060fc3..d84b246 100644 --- a/pyramid_rpc/jsonrpc.py +++ b/pyramid_rpc/jsonrpc.py @@ -316,10 +316,11 @@ def batched_request_view(request): class Endpoint(object): - def __init__(self, name, default_mapper, default_renderer): + def __init__(self, name, default_mapper, default_renderer, require_csrf): self.name = name self.default_mapper = default_mapper self.default_renderer = default_renderer + self.require_csrf = require_csrf def add_jsonrpc_endpoint(config, name, *args, **kw): @@ -341,17 +342,24 @@ def add_jsonrpc_endpoint(config, name, *args, **kw): string name of the renderer, registered via :meth:`pyramid.config.Configurator.add_renderer`. + ``require_csrf`` + + If this argument is specified and is not ``None``, the value will + be passed as the ``require_csrf`` argument to each of the endpoint's + methods, and the batch request view and error view registration. + A JSON-RPC method also accepts all of the arguments supplied to :meth:`pyramid.config.Configurator.add_route`. - """ default_mapper = kw.pop('default_mapper', MapplyViewMapper) default_renderer = kw.pop('default_renderer', DEFAULT_RENDERER) + require_csrf = kw.pop('require_csrf', None) endpoint = Endpoint( name, default_mapper=default_mapper, default_renderer=default_renderer, + require_csrf=require_csrf ) config.registry.jsonrpc_endpoints[name] = endpoint @@ -363,9 +371,11 @@ def add_jsonrpc_endpoint(config, name, *args, **kw): kw['jsonrpc_batched'] = True kw['renderer'] = null_renderer config.add_view(batched_request_view, route_name=name, - permission=NO_PERMISSION_REQUIRED, **kw) + permission=NO_PERMISSION_REQUIRED, + require_csrf=require_csrf, **kw) config.add_view(exception_view, route_name=name, context=Exception, - permission=NO_PERMISSION_REQUIRED) + permission=NO_PERMISSION_REQUIRED, + require_csrf=require_csrf) def add_jsonrpc_method(config, view, **kw): @@ -416,6 +426,10 @@ def add_jsonrpc_method(config, view, **kw): mapper = endpoint.default_mapper kw['mapper'] = mapper + if 'require_csrf' not in kw and endpoint.require_csrf is not None: + # only override mapper if not supplied + kw['require_csrf'] = endpoint.require_csrf + renderer = kw.pop('renderer', None) if renderer is None: renderer = endpoint.default_renderer diff --git a/pyramid_rpc/tests/test_jsonrpc.py b/pyramid_rpc/tests/test_jsonrpc.py index af2f555..2a9a8b4 100644 --- a/pyramid_rpc/tests/test_jsonrpc.py +++ b/pyramid_rpc/tests/test_jsonrpc.py @@ -2,10 +2,18 @@ import unittest from pyramid import testing - +from pyramid.exceptions import BadCSRFToken from webtest import TestApp +class DummySessionFactory(object): + def __init__(self, request): + pass + + def get_csrf_token(self): + return 'abc' + + class Test_add_jsonrpc_method(unittest.TestCase): def setUp(self): @@ -537,6 +545,117 @@ def view(request, a): result = self._callFUT(app, 'dummy', [val]) self.assertEqual(result['result'], val) + def test_require_csrf_False(self): + config = self.config + def view(request): + return 'this must return' + + config = self.config + config.include('pyramid_rpc.jsonrpc') + config.set_default_csrf_options(require_csrf=True) + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False) + config.add_jsonrpc_method(view, endpoint='rpc', method='dummy') + app = config.make_wsgi_app() + app = TestApp(app) + result = self._callFUT(app, 'dummy', [], expect_error=False) + self.assertEqual(result['result'], 'this must return') + + def test_require_csrf_True(self): + config = self.config + def view(request): + return 'this must not return' + + config = self.config + config.include('pyramid_rpc.jsonrpc') + config.set_session_factory(DummySessionFactory) + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) + config.add_jsonrpc_method(view, endpoint='rpc', method='dummy') + app = config.make_wsgi_app() + app = TestApp(app) + with self.assertRaises(BadCSRFToken): + result = self._callFUT(app, 'dummy', []) + + def test_require_csrf_overrideable_on_method(self): + config = self.config + def view(request): + return 'this must return' + + config = self.config + config.include('pyramid_rpc.jsonrpc') + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) + config.add_jsonrpc_method(view, endpoint='rpc', method='dummy', require_csrf=False) + app = config.make_wsgi_app() + app = TestApp(app) + result = self._callFUT(app, 'dummy', [], expect_error=False) + self.assertEqual(result['result'], 'this must return') + + def test_error_require_csrf_False(self): + def view(request): + raise Exception + config = self.config + config.include('pyramid_rpc.jsonrpc') + config.set_default_csrf_options(require_csrf=True) + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False) + config.add_jsonrpc_method(view, endpoint='rpc', method='err') + app = config.make_wsgi_app() + app = TestApp(app) + result = self._callFUT(app, 'err', [], id=None, expect_error=True) + self.assertEqual(result['error']['code'], -32603) + + def test_error_require_csrf_True(self): + def view(request): + raise Exception + config = self.config + config.set_session_factory(DummySessionFactory) + config.include('pyramid_rpc.jsonrpc') + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) + config.add_jsonrpc_method(view, endpoint='rpc', method='err') + app = config.make_wsgi_app() + app = TestApp(app) + with self.assertRaises(BadCSRFToken): + result = self._callFUT(app, 'err', [], id=None, expect_error=True) + + def test_it_with_batched_requests_require_csrf_False(self): + def view(request, a, b): + return [a, b] + config = self.config + config.include('pyramid_rpc.jsonrpc') + config.set_default_csrf_options(require_csrf=True) + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False) + config.add_jsonrpc_method(view, endpoint='rpc', method='dummy') + app = config.make_wsgi_app() + app = TestApp(app) + body = [ + {'id': 1, 'jsonrpc': '2.0', 'method': 'dummy', 'params': [2, 3]}, + {'id': 2, 'jsonrpc': '2.0', 'method': 'dummy', 'params': {'a': 3, 'b': 2}}, + ] + resp = app.post('/api/jsonrpc', content_type='application/json', + params=json.dumps(body)) + self.assertEqual(resp.status_int, 200) + result = resp.json + result1 = [r for r in result if r['id'] == 1][0] + result2 = [r for r in result if r['id'] == 2][0] + self.assertEqual(result1, {'id': 1, 'jsonrpc': '2.0', 'result': [2, 3]}) + self.assertEqual(result2, {'id': 2, 'jsonrpc': '2.0', 'result': [3, 2]}) + + def test_it_with_batched_requests_require_csrf_True_must_fail(self): + def view(request, a, b): + return [a, b] + config = self.config + config.set_session_factory(DummySessionFactory) + config.include('pyramid_rpc.jsonrpc') + config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) + config.add_jsonrpc_method(view, endpoint='rpc', method='dummy') + app = config.make_wsgi_app() + app = TestApp(app) + body = [ + {'id': 1, 'jsonrpc': '2.0', 'method': 'dummy', 'params': [2, 3]}, + {'id': 2, 'jsonrpc': '2.0', 'method': 'dummy', 'params': {'a': 3, 'b': 2}}, + ] + with self.assertRaises(BadCSRFToken): + resp = app.post('/api/jsonrpc', content_type='application/json', + params=json.dumps(body)) + class TestGET(unittest.TestCase): From 3aa5dc30084224905b89085dbaab728d88f02d37 Mon Sep 17 00:00:00 2001 From: Antti Haapala Date: Tue, 29 Aug 2017 21:12:54 +0300 Subject: [PATCH 2/3] try to fix coverage --- pyramid_rpc/tests/test_jsonrpc.py | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/pyramid_rpc/tests/test_jsonrpc.py b/pyramid_rpc/tests/test_jsonrpc.py index 2a9a8b4..ef15aed 100644 --- a/pyramid_rpc/tests/test_jsonrpc.py +++ b/pyramid_rpc/tests/test_jsonrpc.py @@ -546,7 +546,6 @@ def view(request, a): self.assertEqual(result['result'], val) def test_require_csrf_False(self): - config = self.config def view(request): return 'this must return' @@ -561,25 +560,19 @@ def view(request): self.assertEqual(result['result'], 'this must return') def test_require_csrf_True(self): - config = self.config - def view(request): - return 'this must not return' - config = self.config config.include('pyramid_rpc.jsonrpc') config.set_session_factory(DummySessionFactory) config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) - config.add_jsonrpc_method(view, endpoint='rpc', method='dummy') + config.add_jsonrpc_method(lambda request: 'is not called', endpoint='rpc', method='dummy') app = config.make_wsgi_app() app = TestApp(app) with self.assertRaises(BadCSRFToken): result = self._callFUT(app, 'dummy', []) def test_require_csrf_overrideable_on_method(self): - config = self.config def view(request): return 'this must return' - config = self.config config.include('pyramid_rpc.jsonrpc') config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) @@ -590,30 +583,24 @@ def view(request): self.assertEqual(result['result'], 'this must return') def test_error_require_csrf_False(self): - def view(request): - raise Exception config = self.config config.include('pyramid_rpc.jsonrpc') config.set_default_csrf_options(require_csrf=True) config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=False) - config.add_jsonrpc_method(view, endpoint='rpc', method='err') app = config.make_wsgi_app() app = TestApp(app) - result = self._callFUT(app, 'err', [], id=None, expect_error=True) - self.assertEqual(result['error']['code'], -32603) + result = self._callFUT(app, 'err', [], expect_error=True) + self.assertEqual(result['error']['code'], -32601) # invalid method def test_error_require_csrf_True(self): - def view(request): - raise Exception config = self.config config.set_session_factory(DummySessionFactory) config.include('pyramid_rpc.jsonrpc') config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) - config.add_jsonrpc_method(view, endpoint='rpc', method='err') app = config.make_wsgi_app() app = TestApp(app) with self.assertRaises(BadCSRFToken): - result = self._callFUT(app, 'err', [], id=None, expect_error=True) + result = self._callFUT(app, 'err', [], expect_error=True) def test_it_with_batched_requests_require_csrf_False(self): def view(request, a, b): From 5b01847adfb25548ea4e01a9076a33d9e4fd27ff Mon Sep 17 00:00:00 2001 From: Antti Haapala Date: Tue, 29 Aug 2017 21:28:47 +0300 Subject: [PATCH 3/3] fix coverage once more, better tests yet --- pyramid_rpc/tests/test_jsonrpc.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pyramid_rpc/tests/test_jsonrpc.py b/pyramid_rpc/tests/test_jsonrpc.py index ef15aed..6583c6b 100644 --- a/pyramid_rpc/tests/test_jsonrpc.py +++ b/pyramid_rpc/tests/test_jsonrpc.py @@ -564,11 +564,12 @@ def test_require_csrf_True(self): config.include('pyramid_rpc.jsonrpc') config.set_session_factory(DummySessionFactory) config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) - config.add_jsonrpc_method(lambda request: 'is not called', endpoint='rpc', method='dummy') + config.add_jsonrpc_method(lambda: 'not actually called', + endpoint='rpc', method='dummy') app = config.make_wsgi_app() app = TestApp(app) with self.assertRaises(BadCSRFToken): - result = self._callFUT(app, 'dummy', []) + self._callFUT(app, 'dummy', []) def test_require_csrf_overrideable_on_method(self): def view(request): @@ -576,7 +577,8 @@ def view(request): config = self.config config.include('pyramid_rpc.jsonrpc') config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) - config.add_jsonrpc_method(view, endpoint='rpc', method='dummy', require_csrf=False) + config.add_jsonrpc_method(view, endpoint='rpc', + method='dummy', require_csrf=False) app = config.make_wsgi_app() app = TestApp(app) result = self._callFUT(app, 'dummy', [], expect_error=False) @@ -600,7 +602,7 @@ def test_error_require_csrf_True(self): app = config.make_wsgi_app() app = TestApp(app) with self.assertRaises(BadCSRFToken): - result = self._callFUT(app, 'err', [], expect_error=True) + self._callFUT(app, 'err', [], expect_error=True) def test_it_with_batched_requests_require_csrf_False(self): def view(request, a, b): @@ -626,13 +628,11 @@ def view(request, a, b): self.assertEqual(result2, {'id': 2, 'jsonrpc': '2.0', 'result': [3, 2]}) def test_it_with_batched_requests_require_csrf_True_must_fail(self): - def view(request, a, b): - return [a, b] config = self.config config.set_session_factory(DummySessionFactory) config.include('pyramid_rpc.jsonrpc') config.add_jsonrpc_endpoint('rpc', '/api/jsonrpc', require_csrf=True) - config.add_jsonrpc_method(view, endpoint='rpc', method='dummy') + config.add_jsonrpc_method(lambda: 'this is not actually called', endpoint='rpc', method='dummy') app = config.make_wsgi_app() app = TestApp(app) body = [ @@ -640,8 +640,8 @@ def view(request, a, b): {'id': 2, 'jsonrpc': '2.0', 'method': 'dummy', 'params': {'a': 3, 'b': 2}}, ] with self.assertRaises(BadCSRFToken): - resp = app.post('/api/jsonrpc', content_type='application/json', - params=json.dumps(body)) + app.post('/api/jsonrpc', content_type='application/json', + params=json.dumps(body)) class TestGET(unittest.TestCase):