from __future__ import unicode_literals import os import requests from oauthlib.common import log, generate_token, urldecode from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError from oauthlib.oauth2 import TokenExpiredError from .utils import is_secure_transport class TokenUpdated(Warning): def __init__(self, token): super(TokenUpdated, self).__init__() self.token = token class OAuth2Session(requests.Session): """Versatile OAuth 2 extension to :class:`requests.Session`. Supports any grant type adhering to :class:`oauthlib.oauth2.Client` spec including the four core OAuth 2 grants. Can be used to create authorization urls, fetch tokens and access protected resources using the :class:`requests.Session` interface you are used to. - :class:`oauthlib.oauth2.WebApplicationClient` (default): Authorization Code Grant - :class:`oauthlib.oauth2.MobileApplicationClient`: Implicit Grant - :class:`oauthlib.oauth2.LegacyApplicationClient`: Password Credentials Grant - :class:`oauthlib.oauth2.BackendApplicationClient`: Client Credentials Grant Note that the only time you will be using Implicit Grant from python is if you are driving a user agent able to obtain URL fragments. """ def __init__(self, client_id=None, client=None, auto_refresh_url=None, auto_refresh_kwargs=None, scope=None, redirect_uri=None, token=None, state=None, state_generator=None, token_updater=None, **kwargs): """Construct a new OAuth 2 client session. :param client_id: Client id obtained during registration :param client: :class:`oauthlib.oauth2.Client` to be used. Default is WebApplicationClient which is useful for any hosted application but not mobile or desktop. :param scope: List of scopes you wish to request access to :param redirect_uri: Redirect URI you registered as callback :param token: Token dictionary, must include access_token and token_type. :param state: State string used to prevent CSRF. This will be given when creating the authorization url and must be supplied when parsing the authorization response. Can be either a string or a no argument callable. :auto_refresh_url: Refresh token endpoint URL, must be HTTPS. Supply this if you wish the client to automatically refresh your access tokens. :auto_refresh_kwargs: Extra arguments to pass to the refresh token endpoint. :token_updater: Method with one argument, token, to be used to update your token databse on automatic token refresh. If not set a TokenUpdated warning will be raised when a token has been refreshed. This warning will carry the token in its token argument. :param kwargs: Arguments to pass to the Session constructor. """ super(OAuth2Session, self).__init__(**kwargs) self.client_id = client_id or client.client_id self.scope = scope self.redirect_uri = redirect_uri self.token = token or {} self.state = state or generate_token self._state = None self.auto_refresh_url = auto_refresh_url self.auto_refresh_kwargs = auto_refresh_kwargs or {} self.token_updater = token_updater self._client = client or WebApplicationClient(client_id, token=token) self._client._populate_attributes(token or {}) # Allow customizations for non compliant providers through various # hooks to adjust requests and responses. self.compliance_hook = { 'access_token_response': set([]), 'refresh_token_response': set([]), 'protected_request': set([]), } def new_state(self): """Generates a state string to be used in authorizations.""" try: self._state = self.state() log.debug('Generated new state %s.', self._state) except TypeError: self._state = self.state log.debug('Re-using previously supplied state %s.', self._state) return self._state def authorization_url(self, url, **kwargs): """Form an authorization URL. :param url: Authorization endpoint url, must be HTTPS. :param kwargs: Extra parameters to include. :return: authorization_url, state """ state = self.new_state() return self._client.prepare_request_uri(url, redirect_uri=self.redirect_uri, scope=self.scope, state=state, **kwargs), state def fetch_token(self, token_url, code=None, authorization_response=None, body='', auth=None, username=None, password=None, **kwargs): """Generic method for fetching an access token from the token endpoint. If you are using the MobileApplicationClient you will want to use token_from_fragment instead of fetch_token. :param token_url: Token endpoint URL, must use HTTPS. :param code: Authorization code (used by WebApplicationClients). :param authorization_response: Authorization response URL, the callback URL of the request back to you. Used by WebApplicationClients instead of code. :param body: Optional application/x-www-form-urlencoded body to add the include in the token request. Prefer kwargs over body. :param auth: An auth tuple or method as accepted by requests. :param username: Username used by LegacyApplicationClients. :param password: Password used by LegacyApplicationClients. :param kwargs: Extra parameters to include in the token request. :return: A token dict """ if not is_secure_transport(token_url): raise InsecureTransportError() if not code and authorization_response: self._client.parse_request_uri_response(authorization_response, state=self._state) code = self._client.code elif not code and isinstance(self._client, WebApplicationClient): code = self._client.code if not code: raise ValueError('Please supply either code or ' 'authorization_code parameters.') body = self._client.prepare_request_body(code=code, body=body, redirect_uri=self.redirect_uri, username=username, password=password, **kwargs) # (ib-lundgren) All known, to me, token requests use POST. r = self.post(token_url, data=dict(urldecode(body)), headers={'Accept': 'application/json'}, auth=auth) log.debug('Prepared fetch token request body %s', body) log.debug('Request to fetch token completed with status %s.', r.status_code) log.debug('Response headers were %s and content %s.', r.headers, r.text) log.debug('Invoking %d token response hooks.', len(self.compliance_hook['access_token_response'])) for hook in self.compliance_hook['access_token_response']: log.debug('Invoking hook %s.', hook) r = hook(r) self._client.parse_request_body_response(r.text, scope=self.scope) self.token = self._client.token log.debug('Obtained token %s.', self.token) return self.token def token_from_fragment(self, authorization_response): """Parse token from the URI fragment, used by MobileApplicationClients. :param authorization_response: The full URL of the redirect back to you :return: A token dict """ self._client.parse_request_uri_response(authorization_response, state=self._state) self.token = self._client.token return self.token def refresh_token(self, token_url, refresh_token=None, body='', auth=None, **kwargs): """Fetch a new access token using a refresh token. :param token_url: The token endpoint, must be HTTPS. :param refresh_token: The refresh_token to use. :param body: Optional application/x-www-form-urlencoded body to add the include in the token request. Prefer kwargs over body. :param auth: An auth tuple or method as accepted by requests. :param kwargs: Extra parameters to include in the token request. :return: A token dict """ if not token_url: raise ValueError('No token endpoint set for auto_refresh.') if not is_secure_transport(token_url): raise InsecureTransportError() # Need to nullify token to prevent it from being added to the request refresh_token = refresh_token or self.token.get('refresh_token') self.token = {} log.debug('Adding auto refresh key word arguments %s.', self.auto_refresh_kwargs) kwargs.update(self.auto_refresh_kwargs) body = self._client.prepare_refresh_body(body=body, refresh_token=refresh_token, scope=self.scope, **kwargs) log.debug('Prepared refresh token request body %s', body) r = self.post(token_url, data=dict(urldecode(body)), auth=auth) log.debug('Request to refresh token completed with status %s.', r.status_code) log.debug('Response headers were %s and content %s.', r.headers, r.text) log.debug('Invoking %d token response hooks.', len(self.compliance_hook['refresh_token_response'])) for hook in self.compliance_hook['refresh_token_response']: log.debug('Invoking hook %s.', hook) r = hook(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) if not 'refresh_token' in self.token: log.debug('No new refresh token given. Re-using old.') self.token['refresh_token'] = refresh_token return self.token def request(self, method, url, data=None, headers=None, **kwargs): """Intercept all requests and add the OAuth 2 token if present.""" if not is_secure_transport(url): raise InsecureTransportError() if self.token: log.debug('Invoking %d protected resource request hooks.', len(self.compliance_hook['protected_request'])) for hook in self.compliance_hook['protected_request']: log.debug('Invoking hook %s.', hook) url, headers, data = hook(url, headers, data) log.debug('Adding token %s to request.', self.token) try: url, headers, data = self._client.add_token(url, http_method=method, body=data, headers=headers) # Attempt to retrieve and save new access token if expired except TokenExpiredError: if self.auto_refresh_url: log.debug('Auto refresh is set, attempting to refresh at %s.', self.auto_refresh_url) token = self.refresh_token(self.auto_refresh_url) if self.token_updater: log.debug('Updating token to %s using %s.', token, self.token_updater) self.token_updater(token) url, headers, data = self._client.add_token(url, http_method=method, body=data, headers=headers) else: raise TokenUpdated(token) else: raise log.debug('Requesting url %s using method %s.', url, method) log.debug('Supplying headers %s and data %s', headers, data) log.debug('Passing through key word arguments %s.', kwargs) return super(OAuth2Session, self).request(method, url, headers=headers, data=data, **kwargs) def register_compliance_hook(self, hook_type, hook): """Register a hook for request/response tweaking. Available hooks are: access_token_response invoked before token parsing. refresh_token_response invoked before refresh token parsing. protected_request invoked before making a request. If you find a new hook is needed please send a GitHub PR request or open an issue. """ if hook_type not in self.compliance_hook: raise ValueError('Hook type %s is not in %s.', hook_type, self.compliance_hook) self.compliance_hook[hook_type].add(hook)