From a4874ac39527a0601f2a568a0f70cc1a36b9e6dc Mon Sep 17 00:00:00 2001 From: Jamie Lennox Date: Sat, 16 Jun 2018 19:38:44 +1000 Subject: [PATCH 1/4] Refactor token request During token fetch and token refresh we go through a similar flow to actually perform the request. Refactor existing code to use the same function to perform the request. This should have no change to the existing functionality. --- requests_oauthlib/oauth2_session.py | 63 +++++++++++++++++------------ 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index e5ad72c2..ab7850ab 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -210,24 +210,16 @@ def fetch_token(self, token_url, code=None, authorization_response=None, log.debug('Encoding username, password as Basic auth credentials.') auth = requests.auth.HTTPBasicAuth(username, password) - headers = headers or { - 'Accept': 'application/json', - 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', - } self.token = {} - if method.upper() == 'POST': - r = self.post(token_url, data=dict(urldecode(body)), - timeout=timeout, headers=headers, auth=auth, - verify=verify, proxies=proxies) - log.debug('Prepared fetch token request body %s', body) - elif method.upper() == 'GET': - # if method is not 'POST', switch body to querystring and GET - r = self.get(token_url, params=dict(urldecode(body)), - timeout=timeout, headers=headers, auth=auth, - verify=verify, proxies=proxies) - log.debug('Prepared fetch token request querystring %s', body) - else: - raise ValueError('The method kwarg must be POST or GET.') + + r = self._auth_request(method.upper(), + token_url, + body, + timeout=timeout, + headers=headers, + auth=auth, + verify=verify, + proxies=proxies) log.debug('Request to fetch token completed with status %s.', r.status_code) @@ -286,16 +278,16 @@ def refresh_token(self, token_url, refresh_token=None, body='', auth=None, refresh_token=refresh_token, scope=self.scope, **kwargs) log.debug('Prepared refresh token request body %s', body) - if headers is None: - headers = { - 'Accept': 'application/json', - 'Content-Type': ( - 'application/x-www-form-urlencoded;charset=UTF-8' - ), - } + r = self._auth_request('POST', + token_url, + body, + auth=auth, + timeout=timeout, + headers=headers, + verify=verify, + withhold_token=True, + proxies=proxies) - r = self.post(token_url, data=dict(urldecode(body)), auth=auth, - timeout=timeout, headers=headers, verify=verify, withhold_token=True, proxies=proxies) log.debug('Request to refresh token completed with status %s.', r.status_code) log.debug('Response headers were %s and content %s.', @@ -312,6 +304,25 @@ def refresh_token(self, token_url, refresh_token=None, body='', auth=None, self.token['refresh_token'] = refresh_token return self.token + def _auth_request(self, method, url, body, **kwargs): + method = method.upper() + data = dict(urldecode(body)) + kwargs.setdefault('headers', { + 'Accept': 'application/json', + 'Content-Type': 'application/x-www-form-urlencoded;charset=UTF-8', + }) + + if method == 'POST': + kwargs['data'] = data + log.debug('Prepared fetch token request body %s', body) + elif method == 'GET': + kwargs['params'] = data + log.debug('Prepared fetch token request querystring %s', body) + else: + raise ValueError('The method kwarg must be POST or GET.') + + return self.request(method, url, **kwargs) + def request(self, method, url, data=None, headers=None, withhold_token=False, client_id=None, client_secret=None, **kwargs): """Intercept all requests and add the OAuth 2 token if present.""" From cf1e2feb8b9bca8cb236ac5cb978474987e3a052 Mon Sep 17 00:00:00 2001 From: Jamie Lennox Date: Sat, 16 Jun 2018 19:59:38 +1000 Subject: [PATCH 2/4] Register a new compliance hook for token_request I'm working with an oauth2 server we know is broken, however it works for other systems and the company has no interest in fixing it. It requires that I submit auth in JSON format rather than form encoded. We shouldn't support this specific problem, but if we have a compliance hook that allows me to change the sending format of the payload as required I can fix my problem and potentially others in future. Related to: #244 (but not the reason for the patch) --- requests_oauthlib/oauth2_session.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index ab7850ab..05cbe0f1 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -80,6 +80,7 @@ def __init__(self, client_id=None, client=None, auto_refresh_url=None, 'access_token_response': set(), 'refresh_token_response': set(), 'protected_request': set(), + 'token_request': set(), } def new_state(self): @@ -321,6 +322,9 @@ def _auth_request(self, method, url, body, **kwargs): else: raise ValueError('The method kwarg must be POST or GET.') + for hook in self.compliance_hook['token_request']: + method, url, kwargs = hook(method, url, **kwargs) + return self.request(method, url, **kwargs) def request(self, method, url, data=None, headers=None, withhold_token=False, From c53cb0d43eb92465295d254634caad60a2692bb8 Mon Sep 17 00:00:00 2001 From: Jamie Lennox Date: Sat, 16 Jun 2018 23:47:18 +1000 Subject: [PATCH 3/4] Use requests-mock for oauth testing requests-mock is already a dependency that is being used in the compliance tests. It simplifies the process of testing exactly what was sent to the server and mocking out return values. Use requests-mock for the rest of the tests rather than a series of custom maintained mocks. --- tests/test_oauth1_session.py | 129 +++++++++++++++----------------- tests/test_oauth2_session.py | 141 +++++++++++++++++++++-------------- 2 files changed, 147 insertions(+), 123 deletions(-) diff --git a/tests/test_oauth1_session.py b/tests/test_oauth1_session.py index 183244b2..2f4c454f 100644 --- a/tests/test_oauth1_session.py +++ b/tests/test_oauth1_session.py @@ -3,6 +3,7 @@ import unittest import sys import requests +import requests_mock from io import StringIO from oauthlib.oauth1 import SIGNATURE_TYPE_QUERY, SIGNATURE_TYPE_BODY @@ -62,6 +63,9 @@ "jPkI%2FkWMvpxtMrU3Z3KN31WQ%3D%3D" ) +TEST_URL = 'https://i.b' +TEST_TOKEN_URL = 'https://example.com/token' + class OAuth1SessionTest(unittest.TestCase): @@ -70,30 +74,32 @@ def setUp(self): if not hasattr(self, 'assertIn'): self.assertIn = lambda a, b: self.assertTrue(a in b) + self.requests_mock = requests_mock.mock() + self.requests_mock.start() + self.addCleanup(self.requests_mock.stop) + + def assert_signature(self, signature): + header = self.requests_mock.last_request.headers['Authorization'] + self.assertEqual(signature, header.decode('utf-8')) + + return header + def test_signature_types(self): - def verify_signature(getter): - def fake_send(r, **kwargs): - signature = getter(r) - if isinstance(signature, bytes_type): - signature = signature.decode('utf-8') - self.assertIn('oauth_signature', signature) - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - return resp - return fake_send + self.requests_mock.post(TEST_URL) header = OAuth1Session('foo') - header.send = verify_signature(lambda r: r.headers['Authorization']) - header.post('https://i.b') + header.post(TEST_URL) + self.assertIn('oauth_signature', + self.requests_mock.last_request.headers['Authorization'].decode('utf-8')) query = OAuth1Session('foo', signature_type=SIGNATURE_TYPE_QUERY) - query.send = verify_signature(lambda r: r.url) - query.post('https://i.b') + query.post(TEST_URL) + self.assertIn('oauth_signature', self.requests_mock.last_request.url) body = OAuth1Session('foo', signature_type=SIGNATURE_TYPE_BODY) headers = {'Content-Type': 'application/x-www-form-urlencoded'} - body.send = verify_signature(lambda r: r.body) - body.post('https://i.b', headers=headers, data='') + body.post(TEST_URL, headers=headers, data='') + self.assertIn('oauth_signature', self.requests_mock.last_request.text) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @mock.patch('oauthlib.oauth1.rfc5849.generate_nonce') @@ -101,18 +107,19 @@ def test_signature_methods(self, generate_nonce, generate_timestamp): if not cryptography: raise unittest.SkipTest('cryptography module is required') + self.requests_mock.post(TEST_URL) generate_nonce.return_value = 'abc' generate_timestamp.return_value = '123' signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", oauth_signature="h2sRqLArjhlc5p3FTkuNogVHlKE%3D"' auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b') + auth.post(TEST_URL) + self.assert_signature(signature) signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="PLAINTEXT", oauth_consumer_key="foo", oauth_signature="%26"' auth = OAuth1Session('foo', signature_method=SIGNATURE_PLAINTEXT) - auth.send = self.verify_signature(signature) - auth.post('https://i.b') + auth.post(TEST_URL) + self.assert_signature(signature) signature = ('OAuth ' 'oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", ' @@ -121,30 +128,33 @@ def test_signature_methods(self, generate_nonce, generate_timestamp): ).format(sig=TEST_RSA_OAUTH_SIGNATURE) auth = OAuth1Session('foo', signature_method=SIGNATURE_RSA, rsa_key=TEST_RSA_KEY) - auth.send = self.verify_signature(signature) - auth.post('https://i.b') + auth.post(TEST_URL) + self.assert_signature(signature) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @mock.patch('oauthlib.oauth1.rfc5849.generate_nonce') def test_binary_upload(self, generate_nonce, generate_timestamp): + self.requests_mock.post(TEST_URL) + generate_nonce.return_value = 'abc' generate_timestamp.return_value = '123' fake_xml = StringIO('hello world') headers = {'Content-Type': 'application/xml'} signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", oauth_signature="h2sRqLArjhlc5p3FTkuNogVHlKE%3D"' auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) - auth.post('https://i.b', headers=headers, files=[('fake', fake_xml)]) + auth.post(TEST_URL, headers=headers, files=[('fake', fake_xml)]) + self.assert_signature(signature) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @mock.patch('oauthlib.oauth1.rfc5849.generate_nonce') def test_nonascii(self, generate_nonce, generate_timestamp): + self.requests_mock.post('https://i.b') generate_nonce.return_value = 'abc' generate_timestamp.return_value = '123' signature = 'OAuth oauth_nonce="abc", oauth_timestamp="123", oauth_version="1.0", oauth_signature_method="HMAC-SHA1", oauth_consumer_key="foo", oauth_signature="W0haoue5IZAZoaJiYCtfqwMf8x8%3D"' auth = OAuth1Session('foo') - auth.send = self.verify_signature(signature) auth.post('https://i.b?cjk=%E5%95%A6%E5%95%A6') + self.assert_signature(signature) def test_authorization_url(self): auth = OAuth1Session('foo') @@ -165,68 +175,72 @@ def test_parse_response_url(self): def test_fetch_request_token(self): auth = OAuth1Session('foo') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_request_token('https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_request_token(TEST_TOKEN_URL) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) def test_fetch_request_token_with_optional_arguments(self): auth = OAuth1Session('foo') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_request_token('https://example.com/token', - verify=False, stream=True) + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_request_token(TEST_TOKEN_URL, verify=False, stream=True) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) + self.assertFalse(self.requests_mock.last_request.verify) def test_fetch_access_token(self): auth = OAuth1Session('foo', verifier='bar') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_access_token('https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_access_token(TEST_TOKEN_URL) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) def test_fetch_access_token_with_optional_arguments(self): auth = OAuth1Session('foo', verifier='bar') - auth.send = self.fake_body('oauth_token=foo') - resp = auth.fetch_access_token('https://example.com/token', - verify=False, stream=True) + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') + resp = auth.fetch_access_token(TEST_TOKEN_URL, verify=False, stream=True) self.assertEqual(resp['oauth_token'], 'foo') for k, v in resp.items(): self.assertTrue(isinstance(k, unicode_type)) self.assertTrue(isinstance(v, unicode_type)) + self.assertTrue(self.requests_mock.called_once) + self.assertFalse(self.requests_mock.last_request.verify) def _test_fetch_access_token_raises_error(self, auth): """Assert that an error is being raised whenever there's no verifier passed in to the client. """ - auth.send = self.fake_body('oauth_token=foo') + self.requests_mock.post(TEST_TOKEN_URL, text='oauth_token=foo') # Use a try-except block so that we can assert on the exception message # being raised and also keep the Python2.6 compatibility where # assertRaises is not a context manager. try: - auth.fetch_access_token('https://example.com/token') + auth.fetch_access_token(TEST_TOKEN_URL) except ValueError as exc: self.assertEqual('No client verifier has been set.', str(exc)) def test_fetch_token_invalid_response(self): auth = OAuth1Session('foo') - auth.send = self.fake_body('not valid urlencoded response!') - self.assertRaises(ValueError, auth.fetch_request_token, - 'https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, + text='not valid urlencoded response!') + self.assertRaises(ValueError, auth.fetch_request_token, TEST_TOKEN_URL) for code in (400, 401, 403): - auth.send = self.fake_body('valid=response', code) + self.requests_mock.post(TEST_TOKEN_URL, status_code=code) # use try/catch rather than self.assertRaises, so we can # assert on the properties of the exception try: - auth.fetch_request_token('https://example.com/token') + auth.fetch_request_token(TEST_TOKEN_URL) except ValueError as err: self.assertEqual(err.status_code, code) self.assertTrue(isinstance(err.response, requests.Response)) @@ -289,13 +303,13 @@ def test_authorized_false_rsa(self): ).format(sig=TEST_RSA_OAUTH_SIGNATURE) sess = OAuth1Session('foo', signature_method=SIGNATURE_RSA, rsa_key=TEST_RSA_KEY) - sess.send = self.verify_signature(signature) self.assertFalse(sess.authorized) def test_authorized_true(self): + self.requests_mock.post(TEST_TOKEN_URL, + text='oauth_token=foo&oauth_token_secret=bar') sess = OAuth1Session('key', 'secret', verifier='bar') - sess.send = self.fake_body('oauth_token=foo&oauth_token_secret=bar') - sess.fetch_access_token('https://example.com/token') + sess.fetch_access_token(TEST_TOKEN_URL) self.assertTrue(sess.authorized) @mock.patch('oauthlib.oauth1.rfc5849.generate_timestamp') @@ -313,26 +327,7 @@ def test_authorized_true_rsa(self, generate_nonce, generate_timestamp): ).format(sig=TEST_RSA_OAUTH_SIGNATURE) sess = OAuth1Session('key', 'secret', signature_method=SIGNATURE_RSA, rsa_key=TEST_RSA_KEY, verifier='bar') - sess.send = self.fake_body('oauth_token=foo&oauth_token_secret=bar') - sess.fetch_access_token('https://example.com/token') + self.requests_mock.post(TEST_TOKEN_URL, + text='oauth_token=foo&oauth_token_secret=bar') + sess.fetch_access_token(TEST_TOKEN_URL) self.assertTrue(sess.authorized) - - def verify_signature(self, signature): - def fake_send(r, **kwargs): - auth_header = r.headers['Authorization'] - if isinstance(auth_header, bytes_type): - auth_header = auth_header.decode('utf-8') - self.assertEqual(auth_header, signature) - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - return resp - return fake_send - - def fake_body(self, body, status_code=200): - def fake_send(r, **kwargs): - resp = mock.MagicMock(spec=requests.Response) - resp.cookies = [] - resp.text = body - resp.status_code = status_code - return resp - return fake_send diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index e5892cab..1ea322af 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -12,20 +12,12 @@ from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient from requests_oauthlib import OAuth2Session, TokenUpdated +import requests_mock fake_time = time.time() - -def fake_token(token): - def fake_send(r, **kwargs): - resp = mock.MagicMock() - resp.text = json.dumps(token) - return resp - return fake_send - - class OAuth2SessionTest(TestCase): def setUp(self): @@ -48,20 +40,24 @@ def setUp(self): ] self.all_clients = self.clients + [MobileApplicationClient(self.client_id)] - def test_add_token(self): - token = 'Bearer ' + self.token['access_token'] + self.requests_mock = requests_mock.mock() + self.requests_mock.start() + self.addCleanup(self.requests_mock.stop) - def verifier(r, **kwargs): - auth_header = r.headers.get(str('Authorization'), None) - self.assertEqual(auth_header, token) - resp = mock.MagicMock() - resp.cookes = [] - return resp + def test_add_token(self): + self.requests_mock.get('https://i.b', text='Ok') for client in self.all_clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = verifier - auth.get('https://i.b') + resp = auth.get('https://i.b') + self.assertEqual(200, resp.status_code) + + self.assertEqual(len(self.all_clients), + len(self.requests_mock.request_history)) + + token = 'Bearer ' + self.token['access_token'] + for r in self.requests_mock.request_history: + self.assertEqual(token, r.headers.get(str('Authorization'), None)) def test_authorization_url(self): url = 'https://example.com/authorize?foo=bar' @@ -81,58 +77,86 @@ def test_authorization_url(self): self.assertIn('response_type=token', auth_url) @mock.patch("time.time", new=lambda: fake_time) - def test_refresh_token_request(self): + def test_refresh_token_request_no_refresh(self): self.expired_token = dict(self.token) self.expired_token['expires_in'] = '-1' del self.expired_token['expires_at'] - def fake_refresh(r, **kwargs): - if "/refresh" in r.url: - self.assertNotIn("Authorization", r.headers) - resp = mock.MagicMock() - resp.text = json.dumps(self.token) - return resp - # No auto refresh setup for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token) self.assertRaises(TokenExpiredError, auth.get, 'https://i.b') + self.assertFalse(self.requests_mock.called) + + @mock.patch("time.time", new=lambda: fake_time) + def test_refresh_token_request_refresh_no_update(self): + self.expired_token = dict(self.token) + self.expired_token['expires_in'] = '-1' + del self.expired_token['expires_at'] + + m1 = self.requests_mock.get('https://i.b') + m2 = self.requests_mock.post('https://i.b/refresh', json=self.token) + # Auto refresh but no auto update for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token, auto_refresh_url='https://i.b/refresh') - auth.send = fake_refresh self.assertRaises(TokenUpdated, auth.get, 'https://i.b') - # Auto refresh and auto update - def token_updater(token): - self.assertEqual(token, self.token) + self.assertFalse(m1.called) + self.assertEquals(len(self.clients), m2.call_count) + + @mock.patch("time.time", new=lambda: fake_time) + def test_refresh_token_request_refresh_and_update(self): + self.expired_token = dict(self.token) + self.expired_token['expires_in'] = '-1' + del self.expired_token['expires_at'] + + m1 = self.requests_mock.get('https://i.b') + m2 = self.requests_mock.post('https://i.b/refresh', json=self.token) + + token_updater = mock.MagicMock() for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token, auto_refresh_url='https://i.b/refresh', token_updater=token_updater) - auth.send = fake_refresh - auth.get('https://i.b') - - def fake_refresh_with_auth(r, **kwargs): - if "/refresh" in r.url: - self.assertIn("Authorization", r.headers) - encoded = b64encode(b"foo:bar") - content = (b"Basic " + encoded).decode('latin1') - self.assertEqual(r.headers["Authorization"], content) - resp = mock.MagicMock() - resp.text = json.dumps(self.token) - return resp + resp = auth.get('https://i.b') + self.assertEqual(200, resp.status_code) + + self.assertEquals(len(self.clients), m1.call_count) + self.assertEquals(len(self.clients), m2.call_count) + self.assertEquals(len(self.clients), token_updater.call_count) + + @mock.patch("time.time", new=lambda: fake_time) + def test_refresh_token_request_refresh_and_update_2(self): + self.expired_token = dict(self.token) + self.expired_token['expires_in'] = '-1' + del self.expired_token['expires_at'] + + m1 = self.requests_mock.get('https://i.b') + m2 = self.requests_mock.post('https://i.b/refresh', json=self.token) + + token_updater = mock.MagicMock() for client in self.clients: auth = OAuth2Session(client=client, token=self.expired_token, auto_refresh_url='https://i.b/refresh', token_updater=token_updater) - auth.send = fake_refresh_with_auth auth.get('https://i.b', client_id='foo', client_secret='bar') + self.assertEquals(len(self.clients), m1.call_count) + self.assertEquals(len(self.clients), m2.call_count) + self.assertEquals(len(self.clients), token_updater.call_count) + + token = (b"Basic " + b64encode(b"foo:bar")).decode('latin1') + for r in m2.request_history: + self.assertEquals(token, r.headers["Authorization"]) + + for c in token_updater.call_args_list: + self.assertEqual(c, mock.call(self.token)) + @mock.patch("time.time", new=lambda: fake_time) def test_token_from_fragment(self): mobile = MobileApplicationClient(self.client_id) @@ -141,20 +165,27 @@ def test_token_from_fragment(self): self.assertEqual(auth.token_from_fragment(response_url), self.token) @mock.patch("time.time", new=lambda: fake_time) - def test_fetch_token(self): + def test_fetch_token_good(self): url = 'https://example.com/token' + self.requests_mock.post(url, json=self.token) for client in self.clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = fake_token(self.token) self.assertEqual(auth.fetch_token(url), self.token) - error = {'error': 'invalid_request'} + self.assertEqual(len(self.clients), self.requests_mock.call_count) + + @mock.patch("time.time", new=lambda: fake_time) + def test_fetch_token_invalid(self): + url = 'https://example.com/token' + self.requests_mock.post(url, json={'error': 'invalid_request'}) + for client in self.clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = fake_token(error) self.assertRaises(OAuth2Error, auth.fetch_token, url) + self.assertEqual(len(self.clients), self.requests_mock.call_count) + def test_cleans_previous_token_before_fetching_new_one(self): """Makes sure the previous token is cleaned before fetching a new one. @@ -170,12 +201,14 @@ def test_cleans_previous_token_before_fetching_new_one(self): new_token['expires_at'] = now + 3600 url = 'https://example.com/token' + self.requests_mock.post(url, json=new_token) + with mock.patch('time.time', lambda: now): for client in self.clients: auth = OAuth2Session(client=client, token=self.token) - auth.send = fake_token(new_token) self.assertEqual(auth.fetch_token(url), new_token) + self.assertTrue(len(self.clients), self.requests_mock.call_count) def test_web_app_fetch_token(self): # Ensure the state parameter is used, see issue #105. @@ -229,17 +262,13 @@ def test_authorized_false(self): @mock.patch("time.time", new=lambda: fake_time) def test_authorized_true(self): - def fake_token(token): - def fake_send(r, **kwargs): - resp = mock.MagicMock() - resp.text = json.dumps(token) - return resp - return fake_send url = 'https://example.com/token' + self.requests_mock.post(url, json=self.token) for client in self.clients: sess = OAuth2Session(client=client) - sess.send = fake_token(self.token) self.assertFalse(sess.authorized) sess.fetch_token(url) self.assertTrue(sess.authorized) + + self.assertEqual(len(self.clients), self.requests_mock.call_count) From 94bec1b7b2bb0b927abab3315483f45c7ff06000 Mon Sep 17 00:00:00 2001 From: Jamie Lennox Date: Sat, 16 Jun 2018 21:16:15 +1000 Subject: [PATCH 4/4] Throw an exception on a bad http return If the token request fails with an error such as Unauthenticated or anything else the existing code doesn't check for this code and will try and retrieve the token anyway. This gives a somewhat difficult to debug key error when the response doesn't contain the token data. This is wrong, we should always be checking the response code before trusting the response data. This is a slight change in behaviour, we now return this exception instead of a KeyError, however the other exception was difficult to catch. This reuses the failure exception from OAuth1 and makes that public. Closes: #302 --- requests_oauthlib/__init__.py | 1 + requests_oauthlib/exc.py | 10 ++++++++++ requests_oauthlib/oauth1_session.py | 15 ++------------- requests_oauthlib/oauth2_session.py | 10 +++++++++- tests/test_oauth2_session.py | 18 +++++++++++++++++- 5 files changed, 39 insertions(+), 15 deletions(-) create mode 100644 requests_oauthlib/exc.py diff --git a/requests_oauthlib/__init__.py b/requests_oauthlib/__init__.py index 1bb919ee..1f9418ae 100644 --- a/requests_oauthlib/__init__.py +++ b/requests_oauthlib/__init__.py @@ -1,5 +1,6 @@ import logging +from .exc import TokenRequestDenied from .oauth1_auth import OAuth1 from .oauth1_session import OAuth1Session from .oauth2_auth import OAuth2 diff --git a/requests_oauthlib/exc.py b/requests_oauthlib/exc.py new file mode 100644 index 00000000..7be01912 --- /dev/null +++ b/requests_oauthlib/exc.py @@ -0,0 +1,10 @@ +class TokenRequestDenied(ValueError): + + def __init__(self, message, response): + super(TokenRequestDenied, self).__init__(message) + self.response = response + + @property + def status_code(self): + """For backwards-compatibility purposes""" + return self.response.status_code diff --git a/requests_oauthlib/oauth1_session.py b/requests_oauthlib/oauth1_session.py index 53b7b6d1..f6b94421 100644 --- a/requests_oauthlib/oauth1_session.py +++ b/requests_oauthlib/oauth1_session.py @@ -14,6 +14,7 @@ ) import requests +from . import exc from . import OAuth1 @@ -29,18 +30,6 @@ def urldecode(body): return json.loads(body) -class TokenRequestDenied(ValueError): - - def __init__(self, message, response): - super(TokenRequestDenied, self).__init__(message) - self.response = response - - @property - def status_code(self): - """For backwards-compatibility purposes""" - return self.response.status_code - - class TokenMissing(ValueError): def __init__(self, message, response): super(TokenMissing, self).__init__(message) @@ -365,7 +354,7 @@ def _fetch_token(self, url, **request_kwargs): if r.status_code >= 400: error = "Token request failed with code %s, response was '%s'." - raise TokenRequestDenied(error % (r.status_code, r.text), r) + raise exc.TokenRequestDenied(error % (r.status_code, r.text), r) log.debug('Decoding token from response "%s"', r.text) try: diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 05cbe0f1..82f57bb6 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -9,6 +9,8 @@ log = logging.getLogger(__name__) +from . import exc + class TokenUpdated(Warning): def __init__(self, token): @@ -325,7 +327,13 @@ def _auth_request(self, method, url, body, **kwargs): for hook in self.compliance_hook['token_request']: method, url, kwargs = hook(method, url, **kwargs) - return self.request(method, url, **kwargs) + r = self.request(method, url, **kwargs) + + if not r.ok: + error = "Token request failed with code %s, response was '%s'." + raise exc.TokenRequestDenied(error % (r.status_code, r.text), r) + + return r def request(self, method, url, data=None, headers=None, withhold_token=False, client_id=None, client_secret=None, **kwargs): diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index 1ea322af..011256ea 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -11,7 +11,7 @@ from oauthlib.oauth2 import MismatchingStateError from oauthlib.oauth2 import WebApplicationClient, MobileApplicationClient from oauthlib.oauth2 import LegacyApplicationClient, BackendApplicationClient -from requests_oauthlib import OAuth2Session, TokenUpdated +from requests_oauthlib import OAuth2Session, TokenUpdated, TokenRequestDenied import requests_mock @@ -272,3 +272,19 @@ def test_authorized_true(self): self.assertTrue(sess.authorized) self.assertEqual(len(self.clients), self.requests_mock.call_count) + + def test_token_fetch_invalid_status_code(self): + url = 'https://example.com/token' + self.requests_mock.post(url, + json={'message': 'Failure'}, + status_code=403) + + for client in self.clients: + sess = OAuth2Session(client=client) + self.assertRaises( + TokenRequestDenied, + sess.fetch_token, + url + ) + + self.assertEqual(len(self.clients), self.requests_mock.call_count)