From 6c8e16aae91b8c3df9cb2d48fe068576a3d3fdc7 Mon Sep 17 00:00:00 2001 From: Nicolas Dandrimont Date: Mon, 11 Sep 2017 15:23:17 +0200 Subject: [PATCH] lister_transports: allow overriding the parameters to requests --- swh/lister/core/lister_transports.py | 36 +++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/swh/lister/core/lister_transports.py b/swh/lister/core/lister_transports.py index 020eece..4e116f9 100644 --- a/swh/lister/core/lister_transports.py +++ b/swh/lister/core/lister_transports.py @@ -41,6 +41,30 @@ class SWHListerHttpTransport(abc.ABC): 'User-Agent': 'Software Heritage lister (%s)' % self.lister_version } + def request_uri(self, identifier): + """Get the full request URI given the transport_request identifier. + + MAY BE OVERRIDDEN if something more complex than the PATH_TEMPLATE is + required. + """ + path = self.PATH_TEMPLATE % identifier + return self.api_baseurl + path + + def request_params(self, identifier): + """Get the full parameters passed to requests given the transport_request + identifier. + + MAY BE OVERRIDDEN if something more complex than the request headers + ois needed. + """ + params = {} + params['headers'] = self.request_headers() or {} + creds = self.config['credentials'] + auth = random.choice(creds) if creds else None + if auth: + params['auth'] = (auth['username'], auth['password']) + return params + def transport_quota_check(self, response): """Implements SWHListerBase.transport_quota_check with standard 429 code check for HTTP with Requests library. @@ -73,15 +97,11 @@ class SWHListerHttpTransport(abc.ABC): def transport_request(self, identifier): """Implements SWHListerBase.transport_request for HTTP using Requests. """ - path = self.PATH_TEMPLATE % identifier - params = {} - params['headers'] = self.request_headers() or {} - creds = self.config['credentials'] - auth = random.choice(creds) if creds else None - if auth: - params['auth'] = (auth['username'], auth['password']) + path = self.request_uri(identifier) + params = self.request_params(identifier) + try: - response = self.session.get(self.api_baseurl + path, **params) + response = self.session.get(path, **params) except requests.exceptions.ConnectionError as e: raise FetchError(e) else: