Add type annotations to some lister files

This commit is contained in:
Gautier Pugnonblanc Yann 2020-02-17 15:58:34 +01:00
parent 73a33d9224
commit 60adc424be
12 changed files with 148 additions and 85 deletions

View file

@ -7,9 +7,9 @@ import logging
import iso8601
from datetime import datetime, timezone
from typing import Any
from typing import Any, Dict, List, Optional, Union
from urllib import parse
from requests import Response
from swh.lister.bitbucket.models import BitBucketModel
from swh.lister.core.indexing_lister import IndexingHttpLister
@ -26,14 +26,15 @@ class BitBucketLister(IndexingHttpLister):
instance = 'bitbucket'
default_min_bound = datetime.fromtimestamp(0, timezone.utc) # type: Any
def __init__(self, url=None, override_config=None, per_page=100):
def __init__(self, url: str = None,
override_config=None, per_page: int = 100) -> None:
super().__init__(url=url, override_config=override_config)
per_page = self.config.get('per_page', per_page)
self.PATH_TEMPLATE = '%s&pagelen=%s' % (
self.PATH_TEMPLATE, per_page)
def get_model_from_repo(self, repo):
def get_model_from_repo(self, repo: Dict) -> Dict[str, Any]:
return {
'uid': repo['uuid'],
'indexable': iso8601.parse_date(repo['created_on']),
@ -44,7 +45,8 @@ class BitBucketLister(IndexingHttpLister):
'origin_type': repo['scm'],
}
def get_next_target_from_response(self, response):
def get_next_target_from_response(self, response: Response
) -> Union[None, datetime]:
"""This will read the 'next' link from the api response if any
and return it as a datetime.
@ -60,21 +62,24 @@ class BitBucketLister(IndexingHttpLister):
if next_ is not None:
next_ = parse.urlparse(next_)
return iso8601.parse_date(parse.parse_qs(next_.query)['after'][0])
return None
def transport_response_simplified(self, response):
def transport_response_simplified(self, response: Response
) -> List[Dict[str, Any]]:
repos = response.json()['values']
return [self.get_model_from_repo(repo) for repo in repos]
def request_uri(self, identifier):
identifier = parse.quote(identifier.isoformat())
return super().request_uri(identifier or '1970-01-01')
def request_uri(self, identifier: datetime) -> str:
identifier_str = parse.quote(identifier.isoformat())
return super().request_uri(identifier_str or '1970-01-01')
def is_within_bounds(self, inner, lower=None, upper=None):
def is_within_bounds(self, inner: int, lower: Optional[int] = None,
upper: Optional[int] = None) -> bool:
# values are expected to be datetimes
if lower is None and upper is None:
ret = True
elif lower is None:
ret = inner <= upper
ret = inner <= upper # type: ignore
elif upper is None:
ret = inner >= lower
else:

View file

@ -8,8 +8,9 @@ from urllib.parse import urlparse, urljoin
from bs4 import BeautifulSoup
from requests import Session
# from requests.structures import CaseInsensitiveDict
from requests.adapters import HTTPAdapter
from typing import Any, Dict, Generator, Union
from .models import CGitModel
from swh.core.utils import grouper
@ -54,13 +55,14 @@ class CGitLister(ListerBase):
LISTER_NAME = 'cgit'
url_prefix_present = True
def __init__(self, url=None, instance=None, override_config=None):
def __init__(self, url=None, instance=None,
override_config=None):
"""Lister class for CGit repositories.
Args:
url (str): main URL of the CGit instance, i.e. url of the index
url : main URL of the CGit instance, i.e. url of the index
of published git repositories on this instance.
instance (str): Name of cgit instance. Defaults to url's hostname
instance : Name of cgit instance. Defaults to url's hostname
if unset.
"""
@ -79,7 +81,7 @@ class CGitLister(ListerBase):
'User-Agent': USER_AGENT,
}
def run(self):
def run(self) -> Dict[str, str]:
status = 'uneventful'
total = 0
for repos in grouper(self.get_repos(), 10):
@ -94,7 +96,7 @@ class CGitLister(ListerBase):
return {'status': status}
def get_repos(self):
def get_repos(self) -> Generator:
"""Generate git 'project' URLs found on the current CGit server
"""
@ -116,7 +118,7 @@ class CGitLister(ListerBase):
# no pager, or no next page
next_page = None
def build_model(self, repo_url):
def build_model(self, repo_url: str) -> Union[None, Dict[str, Any]]:
"""Given the URL of a git repo project page on a CGit server,
return the repo description (dict) suitable for insertion in the db.
"""
@ -124,7 +126,7 @@ class CGitLister(ListerBase):
urls = [x['href'] for x in bs.find_all('a', {'rel': 'vcs-git'})]
if not urls:
return
return None
# look for the http/https url, if any, and use it as origin_url
for url in urls:
@ -142,7 +144,7 @@ class CGitLister(ListerBase):
'origin_url': origin_url,
}
def get_and_parse(self, url):
def get_and_parse(self, url: str) -> BeautifulSoup:
"Get the given url and parse the retrieved HTML using BeautifulSoup"
return BeautifulSoup(self.session.get(url).text,
features='html.parser')

View file

@ -12,6 +12,9 @@ from sqlalchemy import func
from .lister_transports import ListerHttpTransport
from .lister_base import ListerBase
from requests import Response
from typing import Any, Dict, List, Tuple, Optional
logger = logging.getLogger(__name__)
@ -55,7 +58,7 @@ class IndexingLister(ListerBase):
"""
@abc.abstractmethod
def get_next_target_from_response(self, response):
def get_next_target_from_response(self, response: Response):
"""Find the next server endpoint identifier given the entire response.
Implementation of this method depends on the server API spec
@ -71,7 +74,8 @@ class IndexingLister(ListerBase):
# You probably don't need to override anything below this line.
def filter_before_inject(self, models_list):
def filter_before_inject(
self, models_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Overrides ListerBase.filter_before_inject
Bounds query results by this Lister's set max_index.
@ -100,7 +104,9 @@ class IndexingLister(ListerBase):
retlist = retlist.filter(self.MODEL.indexable <= end)
return retlist
def db_partition_indices(self, partition_size):
def db_partition_indices(
self, partition_size: int
) -> List[Tuple[Optional[int], Optional[int]]]:
"""Describe an index-space compartmentalization of the db table
in equal sized chunks. This is used to describe min&max bounds for
parallelizing fetch tasks.
@ -165,6 +171,7 @@ class IndexingLister(ListerBase):
t = self.db_session.query(func.min(self.MODEL.indexable)).first()
if t:
return t[0]
return None
def db_last_index(self):
"""Look in the db for the largest indexable value
@ -175,8 +182,10 @@ class IndexingLister(ListerBase):
t = self.db_session.query(func.max(self.MODEL.indexable)).first()
if t:
return t[0]
return None
def disable_deleted_repo_tasks(self, start, end, keep_these):
def disable_deleted_repo_tasks(
self, start, end, keep_these):
"""Disable tasks for repos that no longer exist between start and end.
Args:
@ -254,6 +263,7 @@ class IndexingLister(ListerBase):
class IndexingHttpLister(ListerHttpTransport, IndexingLister):
"""Convenience class for ensuring right lookup and init order
when combining IndexingLister and ListerHttpTransport."""
def __init__(self, url=None, override_config=None):
IndexingLister.__init__(self, override_config=override_config)
ListerHttpTransport.__init__(self, url=url)

View file

@ -13,7 +13,7 @@ import time
from sqlalchemy import create_engine, func
from sqlalchemy.orm import sessionmaker
from typing import Any, Dict, List, Type, Union
from typing import Any, Dict, List, Type, Union, Optional
from swh.core import config
from swh.core.utils import grouper
@ -21,6 +21,7 @@ from swh.scheduler import get_scheduler, utils
from .abstractattribute import AbstractAttribute
from requests import Response
logger = logging.getLogger(__name__)
@ -137,7 +138,8 @@ class ListerBase(abc.ABC, config.SWHConfig):
"""
pass
def filter_before_inject(self, models_list):
def filter_before_inject(
self, models_list: List[Dict]) -> List[Dict]:
"""Filter models_list entries prior to injection in the db.
This is ran directly after `transport_response_simplified`.
@ -152,7 +154,8 @@ class ListerBase(abc.ABC, config.SWHConfig):
"""
return models_list
def do_additional_checks(self, models_list):
def do_additional_checks(
self, models_list: List[Dict]) -> List[Dict]:
"""Execute some additional checks on the model list (after the
filtering).
@ -169,7 +172,9 @@ class ListerBase(abc.ABC, config.SWHConfig):
"""
return models_list
def is_within_bounds(self, inner, lower=None, upper=None):
def is_within_bounds(
self, inner: int,
lower: Optional[int] = None, upper: Optional[int] = None) -> bool:
"""See if a sortable value is inside the range [lower,upper].
MAY BE OVERRIDDEN, for example if the server indexable* key is
@ -188,7 +193,7 @@ class ListerBase(abc.ABC, config.SWHConfig):
if lower is None and upper is None:
return True
elif lower is None:
ret = inner <= upper
ret = inner <= upper # type: ignore
elif upper is None:
ret = inner >= lower
else:
@ -262,13 +267,13 @@ class ListerBase(abc.ABC, config.SWHConfig):
"""Reset exponential backoff timeout to initial level."""
self.backoff = self.INITIAL_BACKOFF
def back_off(self):
def back_off(self) -> int:
"""Get next exponential backoff timeout."""
ret = self.backoff
self.backoff *= 10
return ret
def safely_issue_request(self, identifier):
def safely_issue_request(self, identifier: int) -> Optional[Response]:
"""Make network request with retries, rate quotas, and response logs.
Protocol is handled by the implementation of the transport_request
@ -315,7 +320,7 @@ class ListerBase(abc.ABC, config.SWHConfig):
return r
def db_query_equal(self, key, value):
def db_query_equal(self, key: Any, value: Any):
"""Look in the db for a row with key == value
Args:
@ -419,7 +424,8 @@ class ListerBase(abc.ABC, config.SWHConfig):
'[a-zA-Z0-9]',
re.escape(a))
if (isinstance(b, str) and (re.match(a_pattern, b) is None)
or isinstance(c, str) and (re.match(a_pattern, c) is None)):
or isinstance(c, str) and
(re.match(a_pattern, c) is None)):
logger.debug(a_pattern)
raise TypeError('incomparable string patterns detected')
@ -481,7 +487,7 @@ class ListerBase(abc.ABC, config.SWHConfig):
ir, m, _ = tasks[_task_key(task)]
ir.task_id = task['id']
def ingest_data(self, identifier, checks=False):
def ingest_data(self, identifier: int, checks: bool = False):
"""The core data fetch sequence. Request server endpoint. Simplify and
filter response list of repositories. Inject repo information into
local db. Queue loader tasks for linked repositories.

View file

@ -12,7 +12,8 @@ import logging
import requests
import xmltodict
from typing import Optional, Union
from typing import Optional, Union, Dict, Any
from requests import Response
from swh.lister import USER_AGENT_TEMPLATE, __version__
@ -39,7 +40,7 @@ class ListerHttpTransport(abc.ABC):
EXPECTED_STATUS_CODES = (200, 429, 403, 404)
def request_headers(self):
def request_headers(self) -> Dict[str, Any]:
"""Returns dictionary of any request headers needed by the server.
MAY BE OVERRIDDEN if request headers are needed.
@ -97,7 +98,7 @@ class ListerHttpTransport(abc.ABC):
path = self.PATH_TEMPLATE % identifier
return self.url + path
def request_params(self, identifier):
def request_params(self, identifier: int) -> Dict[str, Any]:
"""Get the full parameters passed to requests given the
transport_request identifier.
@ -115,7 +116,8 @@ class ListerHttpTransport(abc.ABC):
return params
auth = random.choice(creds) if creds else None
if auth:
params['auth'] = (auth['username'], auth['password'])
params['auth'] = (auth['username'], # type: ignore
auth['password'])
return params
def transport_quota_check(self, response):
@ -152,7 +154,8 @@ class ListerHttpTransport(abc.ABC):
self.session = requests.Session()
self.lister_version = __version__
def _transport_action(self, identifier, method='get'):
def _transport_action(
self, identifier: int, method: str = 'get') -> Response:
"""Permit to ask information to the api prior to actually executing
query.
@ -176,13 +179,13 @@ class ListerHttpTransport(abc.ABC):
raise FetchError(response)
return response
def transport_head(self, identifier):
def transport_head(self, identifier: int) -> Response:
"""Retrieve head information on api.
"""
return self._transport_action(identifier, method='head')
def transport_request(self, identifier):
def transport_request(self, identifier: int) -> Response:
"""Implements ListerBase.transport_request for HTTP using Requests.
Retrieve get information on api.
@ -190,7 +193,7 @@ class ListerHttpTransport(abc.ABC):
"""
return self._transport_action(identifier)
def transport_response_to_string(self, response):
def transport_response_to_string(self, response: Response) -> str:
"""Implements ListerBase.transport_response_to_string for HTTP given
Requests responses.
"""

View file

@ -13,7 +13,8 @@ import logging
from debian.deb822 import Sources
from sqlalchemy.orm import joinedload, load_only
from sqlalchemy.schema import CreateTable, DropTable
from typing import Mapping, Optional
from typing import Mapping, Optional, Dict, Any
from requests import Response
from swh.lister.debian.models import (
AreaSnapshot, Distribution, DistributionSnapshot, Package,
@ -58,7 +59,7 @@ class DebianLister(ListerHttpTransport, ListerBase):
self.date = override_config.get('date', date) or datetime.datetime.now(
tz=datetime.timezone.utc)
def transport_request(self, identifier):
def transport_request(self, identifier) -> Response:
"""Subvert ListerHttpTransport.transport_request, to try several
index URIs in turn.
@ -94,7 +95,7 @@ class DebianLister(ListerHttpTransport, ListerBase):
# need to return it here.
return identifier
def request_params(self, identifier):
def request_params(self, identifier) -> Dict[str, Any]:
# Enable streaming to allow wrapping the response in the decompressor
# in transport_response_simplified.
params = super().request_params(identifier)

View file

@ -5,11 +5,13 @@
import re
from typing import Any
from typing import Any, Dict, List, Tuple, Optional
from swh.lister.core.indexing_lister import IndexingHttpLister
from swh.lister.github.models import GitHubModel
from requests import Response
class GitHubLister(IndexingHttpLister):
PATH_TEMPLATE = '/repositories?since=%d'
@ -20,7 +22,7 @@ class GitHubLister(IndexingHttpLister):
instance = 'github' # There is only 1 instance of such lister
default_min_bound = 0 # type: Any
def get_model_from_repo(self, repo):
def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]:
return {
'uid': repo['id'],
'indexable': repo['id'],
@ -32,7 +34,7 @@ class GitHubLister(IndexingHttpLister):
'fork': repo['fork'],
}
def transport_quota_check(self, response):
def transport_quota_check(self, response: Response) -> Tuple[bool, int]:
x_rate_limit_remaining = response.headers.get('X-RateLimit-Remaining')
if not x_rate_limit_remaining:
return False, 0
@ -42,17 +44,21 @@ class GitHubLister(IndexingHttpLister):
return True, delay
return False, 0
def get_next_target_from_response(self, response):
def get_next_target_from_response(self,
response: Response) -> Optional[int]:
if 'next' in response.links:
next_url = response.links['next']['url']
return int(self.API_URL_INDEX_RE.match(next_url).group(1))
return int(
self.API_URL_INDEX_RE.match(next_url).group(1)) # type: ignore
return None
def transport_response_simplified(self, response):
def transport_response_simplified(self, response: Response
) -> List[Dict[str, Any]]:
repos = response.json()
return [self.get_model_from_repo(repo)
for repo in repos if repo and 'id' in repo]
def request_headers(self):
def request_headers(self) -> Dict[str, Any]:
"""(Override) Set requests headers to send when querying the GitHub API
"""
@ -60,7 +66,8 @@ class GitHubLister(IndexingHttpLister):
headers['Accept'] = 'application/vnd.github.v3+json'
return headers
def disable_deleted_repo_tasks(self, index, next_index, keep_these):
def disable_deleted_repo_tasks(self, index: int,
next_index: int, keep_these: int):
""" (Overrides) Fix provided index value to avoid erroneously disabling
some scheduler tasks
"""

View file

@ -9,6 +9,9 @@ from urllib3.util import parse_url
from ..core.page_by_page_lister import PageByPageHttpLister
from .models import GitLabModel
from typing import Any, Dict, List, Tuple, Union, MutableMapping, Optional
from requests import Response
class GitLabLister(PageByPageHttpLister):
# Template path expecting an integer that represents the page id
@ -26,10 +29,10 @@ class GitLabLister(PageByPageHttpLister):
self.PATH_TEMPLATE = '%s&sort=%s&per_page=%s' % (
self.PATH_TEMPLATE, sort, per_page)
def uid(self, repo):
def uid(self, repo: Dict[str, Any]) -> str:
return '%s/%s' % (self.instance, repo['path_with_namespace'])
def get_model_from_repo(self, repo):
def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]:
return {
'instance': self.instance,
'uid': self.uid(repo),
@ -40,7 +43,8 @@ class GitLabLister(PageByPageHttpLister):
'origin_type': 'git',
}
def transport_quota_check(self, response):
def transport_quota_check(self, response: Response
) -> Tuple[bool, Union[int, float]]:
"""Deal with rate limit if any.
"""
@ -53,18 +57,22 @@ class GitLabLister(PageByPageHttpLister):
return True, delay
return False, 0
def _get_int(self, headers, key):
def _get_int(self, headers: MutableMapping[str, Any],
key: str) -> Optional[int]:
_val = headers.get(key)
if _val:
return int(_val)
return None
def get_next_target_from_response(self, response):
def get_next_target_from_response(
self, response: Response) -> Optional[int]:
"""Determine the next page identifier.
"""
return self._get_int(response.headers, 'x-next-page')
def get_pages_information(self):
def get_pages_information(self) -> Tuple[Optional[int],
Optional[int], Optional[int]]:
"""Determine pages information.
"""
@ -77,6 +85,7 @@ class GitLabLister(PageByPageHttpLister):
self._get_int(h, 'x-total-pages'),
self._get_int(h, 'x-per-page'))
def transport_response_simplified(self, response):
def transport_response_simplified(self, response: Response
) -> List[Dict[str, Any]]:
repos = response.json()
return [self.get_model_from_repo(repo) for repo in repos]

View file

@ -10,6 +10,8 @@ from swh.lister.core.simple_lister import SimpleLister
from swh.lister.gnu.models import GNUModel
from swh.lister.gnu.tree import GNUTree
from typing import Any, Dict, List
from requests import Response
logger = logging.getLogger(__name__)
@ -58,7 +60,7 @@ class GNULister(SimpleLister):
retries_left=3,
)
def safely_issue_request(self, identifier):
def safely_issue_request(self, identifier: int) -> None:
"""Bypass the implementation. It's now the GNUTree which deals with
querying the gnu mirror.
@ -69,7 +71,7 @@ class GNULister(SimpleLister):
"""
return None
def list_packages(self, response):
def list_packages(self, response: Response) -> List[Dict[str, Any]]:
"""List the actual gnu origins (package name) with their name, url and
associated tarballs.
@ -96,7 +98,7 @@ class GNULister(SimpleLister):
"""
return list(self.gnu_tree.projects.values())
def get_model_from_repo(self, repo):
def get_model_from_repo(self, repo: Dict[str, Any]) -> Dict[str, Any]:
"""Transform from repository representation to model
"""

View file

@ -6,6 +6,9 @@ from swh.lister.core.indexing_lister import IndexingHttpLister
from swh.lister.npm.models import NpmModel
from swh.scheduler.utils import create_task_dict
from typing import Any, Dict, Optional, List
from requests import Response
class NpmListerBase(IndexingHttpLister):
"""List packages available in the npm registry in a paginated way
@ -22,7 +25,7 @@ class NpmListerBase(IndexingHttpLister):
self.PATH_TEMPLATE += '&limit=%s' % self.per_page
@property
def ADDITIONAL_CONFIG(self):
def ADDITIONAL_CONFIG(self) -> Dict[str, Any]:
"""(Override) Add extra configuration
"""
@ -30,7 +33,7 @@ class NpmListerBase(IndexingHttpLister):
default_config['loading_task_policy'] = ('str', 'recurring')
return default_config
def get_model_from_repo(self, repo_name):
def get_model_from_repo(self, repo_name: str) -> Dict[str, str]:
"""(Override) Transform from npm package name to model
"""
@ -45,7 +48,7 @@ class NpmListerBase(IndexingHttpLister):
'origin_type': 'npm',
}
def task_dict(self, origin_type, origin_url, **kwargs):
def task_dict(self, origin_type: str, origin_url: str, **kwargs):
"""(Override) Return task dict for loading a npm package into the
archive.
@ -58,7 +61,7 @@ class NpmListerBase(IndexingHttpLister):
return create_task_dict(task_type, task_policy,
url=origin_url)
def request_headers(self):
def request_headers(self) -> Dict[str, Any]:
"""(Override) Set requests headers to send when querying the npm
registry.
@ -67,7 +70,7 @@ class NpmListerBase(IndexingHttpLister):
headers['Accept'] = 'application/json'
return headers
def string_pattern_check(self, inner, lower, upper=None):
def string_pattern_check(self, inner: int, lower: int, upper: int = None):
""" (Override) Inhibit the effect of that method as packages indices
correspond to package names and thus do not respect any kind
of fixed length string pattern
@ -82,14 +85,16 @@ class NpmLister(NpmListerBase):
"""
PATH_TEMPLATE = '/_all_docs?startkey="%s"'
def get_next_target_from_response(self, response):
def get_next_target_from_response(
self, response: Response) -> Optional[str]:
"""(Override) Get next npm package name to continue the listing
"""
repos = response.json()['rows']
return repos[-1]['id'] if len(repos) == self.per_page else None
def transport_response_simplified(self, response):
def transport_response_simplified(
self, response: Response) -> List[Dict[str, str]]:
"""(Override) Transform npm registry response to list for model manipulation
"""
@ -110,14 +115,16 @@ class NpmIncrementalLister(NpmListerBase):
def CONFIG_BASE_FILENAME(self): # noqa: N802
return 'lister_npm_incremental'
def get_next_target_from_response(self, response):
def get_next_target_from_response(
self, response: Response) -> Optional[str]:
"""(Override) Get next npm package name to continue the listing.
"""
repos = response.json()['results']
return repos[-1]['seq'] if len(repos) == self.per_page else None
def transport_response_simplified(self, response):
def transport_response_simplified(
self, response: Response) -> List[Dict[str, str]]:
"""(Override) Transform npm registry response to list for model
manipulation.
@ -127,7 +134,7 @@ class NpmIncrementalLister(NpmListerBase):
repos = repos[:-1]
return [self.get_model_from_repo(repo['id']) for repo in repos]
def filter_before_inject(self, models_list):
def filter_before_inject(self, models_list: List[Dict[str, Any]]):
"""(Override) Filter out documents in the CouchDB database
not related to a npm package.

View file

@ -14,6 +14,8 @@ from sqlalchemy import func
from swh.lister.core.indexing_lister import IndexingHttpLister
from swh.lister.phabricator.models import PhabricatorModel
from typing import Any, Dict, List, Optional
from requests import Response
logger = logging.getLogger(__name__)
@ -31,7 +33,7 @@ class PhabricatorLister(IndexingHttpLister):
instance = urllib.parse.urlparse(self.url).hostname
self.instance = instance
def request_params(self, identifier):
def request_params(self, identifier: int) -> Dict[str, Any]:
"""Override the default params behavior to retrieve the api token
Credentials are stored as:
@ -61,7 +63,8 @@ class PhabricatorLister(IndexingHttpLister):
headers['Accept'] = 'application/json'
return headers
def get_model_from_repo(self, repo):
def get_model_from_repo(
self, repo: Dict[str, Any]) -> Optional[Dict[str, Any]]:
url = get_repo_url(repo['attachments']['uris']['uris'])
if url is None:
return None
@ -76,12 +79,15 @@ class PhabricatorLister(IndexingHttpLister):
'instance': self.instance,
}
def get_next_target_from_response(self, response):
def get_next_target_from_response(
self, response: Response) -> Optional[int]:
body = response.json()['result']['cursor']
if body['after'] and body['after'] != 'null':
return int(body['after'])
return None
def transport_response_simplified(self, response):
def transport_response_simplified(
self, response: Response) -> List[Optional[Dict[str, Any]]]:
repos = response.json()
if repos['result'] is None:
raise ValueError(
@ -97,7 +103,8 @@ class PhabricatorLister(IndexingHttpLister):
models_list = [m for m in models_list if m is not None]
return super().filter_before_inject(models_list)
def disable_deleted_repo_tasks(self, index, next_index, keep_these):
def disable_deleted_repo_tasks(
self, index: int, next_index: int, keep_these: str):
"""
(Overrides) Fix provided index value to avoid:
@ -117,7 +124,7 @@ class PhabricatorLister(IndexingHttpLister):
return super().disable_deleted_repo_tasks(index, next_index,
keep_these)
def db_first_index(self):
def db_first_index(self) -> Optional[int]:
"""
(Overrides) Filter results by Phabricator instance
@ -128,6 +135,7 @@ class PhabricatorLister(IndexingHttpLister):
t = t.filter(self.MODEL.instance == self.instance).first()
if t:
return t[0]
return None
def db_last_index(self):
"""
@ -141,7 +149,7 @@ class PhabricatorLister(IndexingHttpLister):
if t:
return t[0]
def db_query_range(self, start, end):
def db_query_range(self, start: int, end: int):
"""
(Overrides) Filter the results by the Phabricator instance to
avoid disabling loading tasks for repositories hosted on a
@ -155,14 +163,14 @@ class PhabricatorLister(IndexingHttpLister):
return retlist.filter(self.MODEL.instance == self.instance)
def get_repo_url(attachments):
def get_repo_url(attachments: List[Dict[str, Any]]) -> Optional[int]:
"""
Return url for a hosted repository from its uris attachments according
to the following priority lists:
* protocol: https > http
* identifier: shortname > callsign > id
"""
processed_urls = defaultdict(dict)
processed_urls = defaultdict(dict) # type: Dict[str, Any]
for uri in attachments:
protocol = uri['fields']['builtin']['protocol']
url = uri['fields']['uri']['effective']

View file

@ -12,6 +12,9 @@ from swh.scheduler import utils
from swh.lister.core.simple_lister import SimpleLister
from swh.lister.core.lister_transports import ListerOnePageApiTransport
from typing import Any, Dict
from requests import Response
class PyPILister(ListerOnePageApiTransport, SimpleLister):
MODEL = PyPIModel
@ -23,7 +26,7 @@ class PyPILister(ListerOnePageApiTransport, SimpleLister):
ListerOnePageApiTransport .__init__(self)
SimpleLister.__init__(self, override_config=override_config)
def task_dict(self, origin_type, origin_url, **kwargs):
def task_dict(self, origin_type: str, origin_url: str, **kwargs):
"""(Override) Return task format dict
This is overridden from the lister_base as more information is
@ -35,7 +38,7 @@ class PyPILister(ListerOnePageApiTransport, SimpleLister):
return utils.create_task_dict(
_type, _policy, url=origin_url)
def list_packages(self, response):
def list_packages(self, response: Response) -> list:
"""(Override) List the actual pypi origins from the response.
"""
@ -50,7 +53,7 @@ class PyPILister(ListerOnePageApiTransport, SimpleLister):
"""
return 'https://pypi.org/project/%s/' % repo_name
def get_model_from_repo(self, repo_name):
def get_model_from_repo(self, repo_name: str) -> Dict[str, Any]:
"""(Override) Transform from repository representation to model
"""