diff --git a/atcodertools/client/atcoder.py b/atcodertools/client/atcoder.py index aa37c767..569c588b 100644 --- a/atcodertools/client/atcoder.py +++ b/atcodertools/client/atcoder.py @@ -4,6 +4,7 @@ import warnings from http.cookiejar import LWPCookieJar from typing import List, Optional, Tuple, Union +from urllib3.util.retry import Retry import requests from bs4 import BeautifulSoup @@ -66,7 +67,13 @@ def default_credential_supplier() -> Tuple[str, str]: class AtCoderClient(metaclass=Singleton): def __init__(self): - self._session = requests.Session() + session = requests.Session() + retries = Retry(total=3, + backoff_factor=0.5, + status_forcelist=[status for status in range(400, 600)]) + session.mount("https://", + requests.adapters.HTTPAdapter(max_retries=retries)) + self._session = session def check_logging_in(self): private_url = "https://atcoder.jp/home" @@ -110,10 +117,11 @@ def login(self, save_cookie(self._session) def download_problem_list(self, contest: Contest) -> List[Problem]: - resp = self._request(contest.get_problem_list_url()) - soup = BeautifulSoup(resp.text, "html.parser") - if resp.status_code == 404: + try: + resp = self._request(contest.get_problem_list_url()) + except requests.exceptions.RetryError: raise PageNotFoundError + soup = BeautifulSoup(resp.text, "html.parser") res = [] for tag in soup.find('table').select('tr')[1::]: tag = tag.find("a") diff --git a/atcodertools/tools/envgen.py b/atcodertools/tools/envgen.py index 33ba31bd..dd314e29 100755 --- a/atcodertools/tools/envgen.py +++ b/atcodertools/tools/envgen.py @@ -5,7 +5,6 @@ import sys import traceback from multiprocessing import Pool, cpu_count -import time from typing import Tuple from colorama import Fore @@ -175,24 +174,12 @@ def func(argv: Tuple[AtCoderClient, Problem, Config]): def prepare_contest(atcoder_client: AtCoderClient, contest_id: str, - config: Config, - retry_delay_secs: float = 1.5, - retry_max_delay_secs: float = 60, - retry_max_tries: int = 10): - attempt_count = 1 - while True: - try: - problem_list = atcoder_client.download_problem_list( - Contest(contest_id=contest_id)) - break - except PageNotFoundError: - if 0 < retry_max_tries < attempt_count: - raise EnvironmentInitializationError - logger.warning( - "Failed to fetch. Will retry in {} seconds. (Attempt {})".format(retry_delay_secs, attempt_count)) - time.sleep(retry_delay_secs) - retry_delay_secs = min(retry_delay_secs * 2, retry_max_delay_secs) - attempt_count += 1 + config: Config): + try: + problem_list = atcoder_client.download_problem_list( + Contest(contest_id=contest_id)) + except PageNotFoundError: + raise EnvironmentInitializationError tasks = [(atcoder_client, problem, diff --git a/tests/test_envgen.py b/tests/test_envgen.py index 06b9a5c9..9fbc3c7b 100755 --- a/tests/test_envgen.py +++ b/tests/test_envgen.py @@ -121,9 +121,9 @@ def test_skip_existing_problems(self): self.assertDirectoriesEqual(answer_data_dir_path, self.temp_dir) - @mock.patch('time.sleep') - def test_prepare_contest_aborts_after_max_retry_attempts(self, mock_sleep): + def test_prepare_contest_aborts_after_max_retry_attempts(self): mock_client = mock.Mock(spec=AtCoderClient) + # PageNotFoundError is thrown when requests.exceptions.RetryError occurs mock_client.download_problem_list.side_effect = PageNotFoundError self.assertRaises( EnvironmentInitializationError, @@ -141,17 +141,6 @@ def test_prepare_contest_aborts_after_max_retry_attempts(self, mock_sleep): out_example_format="output_{}.txt" )) ) - self.assertEqual(mock_sleep.call_count, 10) - mock_sleep.assert_has_calls([mock.call(1.5), - mock.call(3.0), - mock.call(6.0), - mock.call(12.0), - mock.call(24.0), - mock.call(48.0), - mock.call(60.0), - mock.call(60.0), - mock.call(60.0), - mock.call(60.0)]) def assertDirectoriesEqual(self, expected_dir_path, dir_path): files1 = get_all_rel_file_paths(expected_dir_path)