diff --git a/pyproject.toml b/pyproject.toml index 529f2fb..e916a19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ show_error_codes = true warn_return_any = true strict_optional = true disallow_incomplete_defs = true -exclude_gitignore = true exclude = ["tests"] [tool.ruff] diff --git a/src/pardner/services/__init__.py b/src/pardner/services/__init__.py index 09a3c85..45a6225 100644 --- a/src/pardner/services/__init__.py +++ b/src/pardner/services/__init__.py @@ -5,4 +5,5 @@ from pardner.services.base import ( UnsupportedVerticalException as UnsupportedVerticalException, ) +from pardner.services.strava import StravaTransferService as StravaTransferService from pardner.services.tumblr import TumblrTransferService as TumblrTransferService diff --git a/src/pardner/services/base.py b/src/pardner/services/base.py index 4a3db36..728565d 100644 --- a/src/pardner/services/base.py +++ b/src/pardner/services/base.py @@ -3,6 +3,7 @@ from requests_oauthlib import OAuth2Session +from pardner.services.utils import scope_as_set, scope_as_string from pardner.verticals import Vertical @@ -47,6 +48,7 @@ def __init__( client_secret: str, redirect_uri: str, supported_verticals: set[Vertical], + state: Optional[str] = None, verticals: set[Vertical] = set(), ) -> None: """ @@ -58,16 +60,18 @@ def __init__( :param client_secret: The `client_secret` paired to the `client_id`. :param redirect_uri: The registered callback URI. :param supported_verticals: The `Vertical`s that can be fetched on the service. + :param state: State string used to prevent CSRF and identify flow. :param verticals: The `Vertical`s for which the transfer service has appropriate scope to fetch. """ - self._oAuth2Session = OAuth2Session( - client_id=client_id, redirect_uri=redirect_uri - ) self._client_secret = client_secret self._supported_verticals = supported_verticals self._service_name = service_name self._verticals = verticals + self._oAuth2Session = OAuth2Session( + client_id=client_id, redirect_uri=redirect_uri, state=state + ) + self.scope = self.scope_for_verticals(verticals) @property def name(self) -> str: @@ -75,11 +79,23 @@ def name(self) -> str: @property def scope(self) -> set[str]: - return self._oAuth2Session.scope if self._oAuth2Session.scope else set() + return ( + scope_as_set(self._oAuth2Session.scope) + if self._oAuth2Session.scope + else set() + ) @scope.setter def scope(self, new_scope: Iterable[str]) -> None: - self._oAuth2Session.scope = set(new_scope) + """ + Sets the scope of the transfer service flow. + Some services have specific requirements for the format of the scope + string (e.g., scopes have to be comma separated, or `+` separated). + + :param new_scope: The new scopes that should be set for the transfer + service. + """ + self._oAuth2Session.scope = scope_as_string(new_scope) @property def verticals(self) -> set[Vertical]: @@ -118,22 +134,23 @@ def add_verticals( """ new_verticals = set(verticals) - self.verticals new_scopes = self.scope_for_verticals(new_verticals) - original_scopes: set[str] = self.scope if self.scope else set() - if not new_scopes.issubset(original_scopes) and not should_reauth: + if not new_scopes.issubset(self.scope) and not should_reauth: raise InsufficientScopeException(verticals, self.name) - elif not new_scopes.issubset(original_scopes): + elif not new_scopes.issubset(self.scope): self.verticals = new_verticals | self.verticals del self._oAuth2Session.access_token - self.scope = original_scopes | new_scopes + self.scope = self.scope | new_scopes return False self.verticals = new_verticals | self.verticals return True - @abstractmethod def fetch_token( - self, code: Optional[str] = None, authorization_response: Optional[str] = None + self, + code: Optional[str] = None, + authorization_response: Optional[str] = None, + include_client_id: bool = False, ) -> dict[str, Any]: """ Once the end-user authorizes the application to access their data, the @@ -147,12 +164,18 @@ def fetch_token( browser redirected to. :param authorization_response: the URL (with parameters) the end-user's browser redirected to after authorization. + :param include_client_id: whether or not to send the client ID with the token request :returns: the authorization URL and state, respectively. """ - pass + return self._oAuth2Session.fetch_token( + token_url=self._token_url, + code=code, + authorization_response=authorization_response, + include_client_id=include_client_id, + client_secret=self._client_secret, + ) - @abstractmethod def authorization_url(self) -> tuple[str, str]: """ Builds the authorization URL and state. Once the end-user (i.e., resource owner) @@ -160,7 +183,7 @@ def authorization_url(self) -> tuple[str, str]: :returns: the authorization URL and state, respectively. """ - pass + return self._oAuth2Session.authorization_url(self._authorization_url) @abstractmethod def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]: diff --git a/src/pardner/services/strava.py b/src/pardner/services/strava.py new file mode 100644 index 0000000..75db660 --- /dev/null +++ b/src/pardner/services/strava.py @@ -0,0 +1,61 @@ +from typing import Any, Iterable, Optional, override + +from pardner.services.base import BaseTransferService, UnsupportedVerticalException +from pardner.services.utils import scope_as_set, scope_as_string +from pardner.verticals import Vertical + + +class StravaTransferService(BaseTransferService): + """ + Class responsible for obtaining end-user authorization to make requests to + Strava's API. + See API documentation: https://developers.strava.com/docs/reference/ + """ + + _authorization_url = 'https://www.strava.com/oauth/authorize' + _token_url = 'https://www.strava.com/oauth/token' + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_uri: str, + state: Optional[str] = None, + verticals: set[Vertical] = set(), + ) -> None: + super().__init__( + service_name='Strava', + client_id=client_id, + client_secret=client_secret, + redirect_uri=redirect_uri, + state=state, + supported_verticals={Vertical.FeedPost}, + verticals=verticals, + ) + + @property + def scope(self) -> set[str]: + return scope_as_set(self._oAuth2Session.scope, delimiter=',') + + @scope.setter + def scope(self, new_scope: Iterable[str] | str) -> None: + self._oAuth2Session.scope = scope_as_string(new_scope, delimiter=',') + + @override + def fetch_token( + self, + code: Optional[str] = None, + authorization_response: Optional[str] = None, + include_client_id: bool = True, + ) -> dict[str, Any]: + return super().fetch_token(code, authorization_response, include_client_id) + + @override + def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]: + sub_scopes: set[str] = set() + for vertical in verticals: + if vertical not in self._supported_verticals: + raise UnsupportedVerticalException([vertical], self._service_name) + if vertical == Vertical.FeedPost: + sub_scopes.update(['activity:read', 'profile:read_all']) + return sub_scopes diff --git a/src/pardner/services/tumblr.py b/src/pardner/services/tumblr.py index 8cb48bb..246b422 100644 --- a/src/pardner/services/tumblr.py +++ b/src/pardner/services/tumblr.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, override from pardner.services import BaseTransferService from pardner.verticals import Vertical @@ -19,6 +19,7 @@ def __init__( client_id: str, client_secret: str, redirect_uri: str, + state: Optional[str] = None, verticals: set[Vertical] = set(), ) -> None: super().__init__( @@ -26,25 +27,21 @@ def __init__( client_id=client_id, client_secret=client_secret, redirect_uri=redirect_uri, + state=state, supported_verticals={Vertical.FeedPost}, verticals=verticals, ) + @override def scope_for_verticals(self, verticals: Iterable[Vertical]) -> set[str]: # Tumblr only needs 'base' for read access requests return {'base'} - def authorization_url(self) -> tuple[str, str]: - return self._oAuth2Session.authorization_url(self._authorization_url) - + @override def fetch_token( - self, code: Optional[str] = None, authorization_response: Optional[str] = None + self, + code: Optional[str] = None, + authorization_response: Optional[str] = None, + include_client_id: bool = True, ) -> dict[str, Any]: - # Requires client_id - return self._oAuth2Session.fetch_token( - token_url=self._token_url, - code=code, - authorization_response=authorization_response, - include_client_id=True, - client_secret=self._client_secret, - ) + return super().fetch_token(code, authorization_response, include_client_id) diff --git a/src/pardner/services/utils.py b/src/pardner/services/utils.py new file mode 100644 index 0000000..3519ee4 --- /dev/null +++ b/src/pardner/services/utils.py @@ -0,0 +1,34 @@ +from typing import Any + + +def scope_as_string(scopes: Any, delimiter: str = ' ') -> str | None: + """ + Converts a sequence of individual scopes into a single scope string. + + :param scopes: a sequence of scopes as strings or a scope string. + :param delimiter: the string used to separate individual scopes. Defaults to single space. + + :returns: a string containing all scopes. + :raises :class:ValueError: if `scopes` is neither a string nor a sequence of strings + """ + if isinstance(scopes, str) or scopes is None: + return scopes + elif isinstance(scopes, (set, tuple, list)): + return delimiter.join([str(s) for s in sorted(scopes)]) + raise ValueError(f'Invalid scope ({scopes}), must be string, tuple, set, or list.') + + +def scope_as_set(scope: Any, delimiter: str = ' ') -> set[str]: + """ + Splits a scope with potentially more than one scope into a set of scopes. + + :param scope: a string with one or more scopes. + :param delimiter: the string used to separate individual scopes. Defaults to single space. + + :returns: a set of scopes. + """ + if isinstance(scope, (tuple, list, set)): + return {str(s) for s in scope} + elif scope is None: + return set() + return set(scope.strip().split(delimiter)) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_transfer_services/__init__.py b/tests/test_transfer_services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_transfer_services/conftest.py b/tests/test_transfer_services/conftest.py new file mode 100644 index 0000000..e73784f --- /dev/null +++ b/tests/test_transfer_services/conftest.py @@ -0,0 +1,34 @@ +import pytest + +from pardner.services.strava import StravaTransferService +from pardner.services.tumblr import TumblrTransferService +from pardner.verticals.base import Vertical + + +@pytest.fixture +def mock_oAuth2Session(mocker): + mock_oauth2session_request = mocker.patch('requests_oauthlib.OAuth2Session.request') + mock_client_parse_request_body_response = mocker.patch( + 'oauthlib.oauth2.rfc6749.clients.WebApplicationClient.parse_request_body_response' + ) + return [mock_oauth2session_request, mock_client_parse_request_body_response] + + +@pytest.fixture +def mock_vertical(): + Vertical.NEW_VERTICAL = 'new_vertical' + Vertical.NEW_VERTICAL_EXTRA_SCOPE = 'new_vertical_unsupported' + + +@pytest.fixture +def mock_tumblr_transfer_service(verticals=[Vertical.FeedPost]): + return TumblrTransferService( + 'fake_client_id', 'fake_client_secret', 'https://redirect_uri', None, verticals + ) + + +@pytest.fixture +def mock_strava_transfer_service(verticals=[Vertical.FeedPost]): + return StravaTransferService( + 'fake_client_id', 'fake_client_secret', 'https://redirect_uri', None, verticals + ) diff --git a/tests/test_transfer_services/test_base.py b/tests/test_transfer_services/test_base.py index 741e579..ee35035 100644 --- a/tests/test_transfer_services/test_base.py +++ b/tests/test_transfer_services/test_base.py @@ -1,3 +1,5 @@ +from urllib import parse + import pytest from pardner.services import ( @@ -11,44 +13,43 @@ class FakeTransferService(BaseTransferService): - def __init__(self, supported_verticals, verticals): - super().__init__('Fake Transfer Service', '', '', '', set()) - self._supported_verticals = set(supported_verticals) - self._verticals = set(verticals) + _authorization_url = 'https://auth_url' + _token_url = 'https://token_url' - def authorization_url(self): - pass - - def fetch_token(self, code=None, authorization_response=None): - pass + def __init__(self, supported_verticals, verticals): + super().__init__( + 'Fake Transfer Service', + 'fake_client_id', + 'fake_client_secret', + 'https://redirect_uri', + set(supported_verticals), + None, + set(verticals), + ) def scope_for_verticals(self, verticals): + if Vertical.NEW_VERTICAL_EXTRA_SCOPE in verticals: + return sample_scope | {'extra_scope'} return sample_scope -@pytest.fixture -def mock_vertical(monkeypatch): - Vertical.NEW_VERTICAL = 'new_vertical' - Vertical.NEW_VERTICAL_EXTRA_SCOPE = 'new_vertical_unsupported' - - @pytest.fixture def blank_transfer_service(monkeypatch): - return FakeTransferService([], []) + return FakeTransferService([Vertical.FeedPost], []) def test_add_verticals_raises_exception(mock_vertical, blank_transfer_service): with pytest.raises(InsufficientScopeException): - blank_transfer_service.add_verticals([Vertical.FeedPost]) + blank_transfer_service.add_verticals([Vertical.NEW_VERTICAL_EXTRA_SCOPE]) def test_set_verticals_raises_exception(mock_vertical, blank_transfer_service): with pytest.raises(UnsupportedVerticalException): - blank_transfer_service.verticals = [Vertical.FeedPost] + blank_transfer_service.verticals = [Vertical.NEW_VERTICAL] @pytest.fixture -def transfer_service(mock_vertical): +def mock_transfer_service(mock_vertical): mock_transfer_service = FakeTransferService( [Vertical.FeedPost, Vertical.NEW_VERTICAL, Vertical.NEW_VERTICAL_EXTRA_SCOPE], [Vertical.FeedPost], @@ -57,34 +58,60 @@ def transfer_service(mock_vertical): return mock_transfer_service -def test_set_supported_verticals(mock_vertical, transfer_service): - transfer_service.verticals = [Vertical.NEW_VERTICAL] - assert transfer_service.verticals == {Vertical.NEW_VERTICAL} +def test_set_supported_verticals(mock_vertical, mock_transfer_service): + mock_transfer_service.verticals = [Vertical.NEW_VERTICAL] + assert mock_transfer_service.verticals == {Vertical.NEW_VERTICAL} -def test_add_supported_verticals(mock_vertical, transfer_service): - assert transfer_service.add_verticals([Vertical.NEW_VERTICAL]) - assert transfer_service.verticals == {Vertical.FeedPost, Vertical.NEW_VERTICAL} +def test_add_supported_verticals(mock_vertical, mock_transfer_service): + assert mock_transfer_service.add_verticals([Vertical.NEW_VERTICAL]) + assert mock_transfer_service.verticals == {Vertical.FeedPost, Vertical.NEW_VERTICAL} def test_add_unsupported_vertical_new_scope_required( - monkeypatch, mock_vertical, transfer_service + monkeypatch, mock_vertical, mock_transfer_service ): def _mock_scope_for_verticals(verticals): if Vertical.NEW_VERTICAL_EXTRA_SCOPE in verticals: return {'new_scope'} return sample_scope - transfer_service._oAuth2Session.access_token = 'access_token' + mock_transfer_service._oAuth2Session.access_token = 'access_token' monkeypatch.setattr( - transfer_service, 'scope_for_verticals', _mock_scope_for_verticals + mock_transfer_service, 'scope_for_verticals', _mock_scope_for_verticals ) - assert not transfer_service.add_verticals( + assert not mock_transfer_service.add_verticals( [Vertical.NEW_VERTICAL_EXTRA_SCOPE], should_reauth=True ) - assert not transfer_service._oAuth2Session.access_token - assert transfer_service.scope == {'fake', 'scope', 'new_scope'} - assert transfer_service.verticals == { + assert not mock_transfer_service._oAuth2Session.access_token + assert mock_transfer_service.scope == {'fake', 'scope', 'new_scope'} + assert mock_transfer_service.verticals == { Vertical.FeedPost, Vertical.NEW_VERTICAL_EXTRA_SCOPE, } + + +def test_authorization_url(mock_transfer_service): + auth_url, state = mock_transfer_service.authorization_url() + + auth_url_query = parse.urlsplit(auth_url).query + auth_url_params = dict(parse.parse_qsl(auth_url_query)) + + assert 'client_id' in auth_url_params + assert auth_url_params['client_id'] == 'fake_client_id' + assert 'redirect_uri' in auth_url_params + assert auth_url_params['redirect_uri'] == 'https://redirect_uri' + assert 'state' in auth_url_params + assert auth_url_params['state'] == state + + +def test_fetch_token_raises_error(mock_transfer_service): + with pytest.raises(ValueError): + mock_transfer_service.fetch_token() + + +def test_fetch_token(mock_oAuth2Session, mock_strava_transfer_service): + [mock_request, mock_response] = mock_oAuth2Session + mock_strava_transfer_service.fetch_token(code='123code123') + mock_request.assert_called_once() + mock_response.assert_called_once() diff --git a/tests/test_transfer_services/test_strava.py b/tests/test_transfer_services/test_strava.py new file mode 100644 index 0000000..78fa72d --- /dev/null +++ b/tests/test_transfer_services/test_strava.py @@ -0,0 +1,23 @@ +import pytest + +from pardner.services.base import UnsupportedVerticalException +from pardner.verticals import Vertical + +sample_scope = {'fake', 'scope'} + + +def test_scope(mock_strava_transfer_service): + mock_strava_transfer_service.scope == 'activity:read,profile:read_all' + + +@pytest.mark.parametrize( + ['verticals', 'expected_scope'], + [([], set()), ([Vertical.FeedPost], {'activity:read', 'profile:read_all'})], +) +def test_scope_for_verticals(mock_strava_transfer_service, verticals, expected_scope): + assert mock_strava_transfer_service.scope_for_verticals(verticals) == expected_scope + + +def test_scope_for_verticals_raises_error(mock_strava_transfer_service, mock_vertical): + with pytest.raises(UnsupportedVerticalException): + mock_strava_transfer_service.scope_for_verticals([Vertical.NEW_VERTICAL]) diff --git a/tests/test_transfer_services/test_transfer_services_common.py b/tests/test_transfer_services/test_transfer_services_common.py new file mode 100644 index 0000000..5445a57 --- /dev/null +++ b/tests/test_transfer_services/test_transfer_services_common.py @@ -0,0 +1,14 @@ +import pytest + + +@pytest.mark.parametrize( + 'mock_transfer_service_name', + ['mock_tumblr_transfer_service', 'mock_strava_transfer_service'], +) +def test_fetch_token(mock_oAuth2Session, request, mock_transfer_service_name): + [mock_request, _] = mock_oAuth2Session + mock_transfer_service = request.getfixturevalue(mock_transfer_service_name) + mock_transfer_service.fetch_token(code='123code123') + mock_request.assert_called_once() + assert 'client_id' in mock_request.call_args.kwargs['data'] + assert mock_request.call_args.kwargs['data']['client_id'] == 'fake_client_id' diff --git a/tests/test_transfer_services/test_tumblr.py b/tests/test_transfer_services/test_tumblr.py index cdea3cb..a971132 100644 --- a/tests/test_transfer_services/test_tumblr.py +++ b/tests/test_transfer_services/test_tumblr.py @@ -1,51 +1,12 @@ -from urllib import parse - import pytest -from pardner.services import TumblrTransferService from pardner.verticals import Vertical sample_scope = {'fake', 'scope'} -@pytest.fixture -def mock_tumblr_transfer_service(monkeypatch, verticals=[Vertical.FeedPost]): - return TumblrTransferService( - 'fake_client_id', 'fake_client_secret', 'https://redirect_uri', verticals - ) - - @pytest.mark.parametrize( ['verticals', 'expected_scope'], [([], {'base'}), ([Vertical.FeedPost, {'base'}])] ) def test_scope_for_vertical(mock_tumblr_transfer_service, verticals, expected_scope): assert mock_tumblr_transfer_service.scope_for_verticals(verticals) == expected_scope - - -def test_authorization_url(mock_tumblr_transfer_service): - auth_url, state = mock_tumblr_transfer_service.authorization_url() - - auth_url_query = parse.urlsplit(auth_url).query - auth_url_params = dict(parse.parse_qsl(auth_url_query)) - - assert 'client_id' in auth_url_params - assert auth_url_params['client_id'] == 'fake_client_id' - assert 'redirect_uri' in auth_url_params - assert auth_url_params['redirect_uri'] == 'https://redirect_uri' - assert 'state' in auth_url_params - assert auth_url_params['state'] == state - - -def test_fetch_token_raises_error(mock_tumblr_transfer_service): - with pytest.raises(ValueError): - mock_tumblr_transfer_service.fetch_token() - - -def test_fetch_token(mocker, mock_tumblr_transfer_service): - mock_oauth2session_request = mocker.patch('requests_oauthlib.OAuth2Session.request') - mock_client_parse_request_body_response = mocker.patch( - 'oauthlib.oauth2.rfc6749.clients.WebApplicationClient.parse_request_body_response' - ) - mock_tumblr_transfer_service.fetch_token(code='123code123') - mock_oauth2session_request.assert_called_once() - mock_client_parse_request_body_response.assert_called_once() diff --git a/tests/test_transfer_services/test_utils.py b/tests/test_transfer_services/test_utils.py new file mode 100644 index 0000000..237c5fe --- /dev/null +++ b/tests/test_transfer_services/test_utils.py @@ -0,0 +1,34 @@ +import pytest + +from pardner.services.utils import scope_as_set, scope_as_string + + +@pytest.mark.parametrize( + ['scopes', 'delimiter', 'expected', 'expected_with_delimiter'], + [ + ({'first', 'second'}, '+', 'first second', 'first+second'), + ('first+second', '+', 'first+second', 'first+second'), + ], +) +def test_scope_as_string(scopes, delimiter, expected, expected_with_delimiter): + assert scope_as_string(scopes) == expected + assert scope_as_string(scopes, delimiter) == expected_with_delimiter + + +def test_scope_as_string_raises_error(): + with pytest.raises(ValueError): + scope_as_string(100) + + +@pytest.mark.parametrize( + ['scope', 'delimiter', 'expected', 'expected_with_delimiter'], + [ + ('first+second', '+', {'first+second'}, {'first', 'second'}), + ({'first', 'second'}, '+', {'first', 'second'}, {'first', 'second'}), + ([], '--', set(), set()), + (None, '--', set(), set()), + ], +) +def test_scope_as_set(scope, delimiter, expected, expected_with_delimiter): + assert scope_as_set(scope) == expected + assert scope_as_set(scope, delimiter) == expected_with_delimiter