Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit 167e7f9

Browse files
committed
ci: Fix up all type errors
1 parent e73ac0d commit 167e7f9

20 files changed

+565
-380
lines changed

.github/workflows/typecheck-mypy.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ jobs:
2424
else
2525
echo "::add-matcher::.github/workflows/mypy-matcher.json"
2626
fi
27-
python -m mypy --no-color-output src/ai/backend || exit 0
27+
python -m mypy --no-color-output src/ai/backend

src/ai/backend/client/cli/admin/images.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def images(operation: bool) -> None:
3636
print_warn('There are no registered images.')
3737
return
3838
print(tabulate((item.values() for item in items),
39-
headers=(item[0] for item in fields),
39+
headers=[item[0] for item in fields],
4040
floatfmt=',.0f'))
4141

4242

src/ai/backend/client/cli/app.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class ProxyRunnerContext:
128128
protocol: str
129129
host: str
130130
port: int
131-
args: Dict[str, str]
131+
args: Dict[str, Union[None, str, List[str]]]
132132
envs: Dict[str, str]
133133
api_session: Optional[AsyncSession]
134134
local_server: Optional[asyncio.AbstractServer]
@@ -150,7 +150,7 @@ def __init__(self, host: str, port: int,
150150
self.exit_code = 0
151151

152152
self.args, self.envs = {}, {}
153-
if len(args) > 0:
153+
if args is not None and len(args) > 0:
154154
for argline in args:
155155
tokens = []
156156
for token in shlex.shlex(argline,
@@ -168,7 +168,7 @@ def __init__(self, host: str, port: int,
168168
self.args[tokens[0]] = tokens[1]
169169
else:
170170
self.args[tokens[0]] = tokens[1:]
171-
if len(envs) > 0:
171+
if envs is not None and len(envs) > 0:
172172
for envline in envs:
173173
split = envline.strip().split('=', maxsplit=2)
174174
if len(split) == 2:
@@ -178,6 +178,7 @@ def __init__(self, host: str, port: int,
178178

179179
async def handle_connection(self, reader: asyncio.StreamReader,
180180
writer: asyncio.StreamWriter) -> None:
181+
assert self.api_session is not None
181182
p = WSProxy(self.api_session, self.session_name,
182183
self.app_name, self.protocol,
183184
self.args, self.envs,
@@ -232,6 +233,7 @@ async def __aexit__(self, *exc_info) -> None:
232233
print_info("Shutting down....")
233234
self.local_server.close()
234235
await self.local_server.wait_closed()
236+
assert self.api_session is not None
235237
await self.api_session.__aexit__(*exc_info)
236238
assert self.api_session.closed
237239
if self.local_server is not None:

src/ai/backend/client/cli/proxy.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import json
35
import re
6+
from typing import (
7+
Union,
8+
Tuple,
9+
)
410

511
import aiohttp
612
from aiohttp import web
@@ -19,6 +25,8 @@ class WebSocketProxy:
1925
'upstream_buffer', 'upstream_buffer_task',
2026
)
2127

28+
upstream_buffer: asyncio.Queue[Tuple[Union[str, bytes], aiohttp.WSMsgType]]
29+
2230
def __init__(self, up_conn: aiohttp.ClientWebSocketResponse,
2331
down_conn: web.WebSocketResponse):
2432
self.up_conn = up_conn

src/ai/backend/client/cli/run.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from ..compat import asyncio_run, current_loop
2323
from ..exceptions import BackendError, BackendAPIError
2424
from ..session import Session, AsyncSession, is_legacy_server
25-
from ..utils import undefined
25+
from ..types import undefined
2626
from .pretty import (
2727
print_info, print_wait, print_done, print_error, print_fail, print_warn,
2828
format_info,
@@ -881,7 +881,7 @@ def start(image, name, owner, # base args
881881
help='Set the owner of the target session explicitly.')
882882
# job scheduling options
883883
@click.option('--type', 'type_', metavar='SESSTYPE',
884-
type=click.Choice(['batch', 'interactive', undefined]),
884+
type=click.Choice(['batch', 'interactive', undefined]), # type: ignore
885885
default=undefined,
886886
help='Either batch or interactive')
887887
@click.option('-i', '--image', default=undefined,

src/ai/backend/client/config.py

+42-37
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import (
66
Any, Callable, Iterable, Union,
77
List, Tuple, Sequence,
8+
Mapping,
9+
cast,
810
)
911

1012
import appdirs
@@ -39,7 +41,7 @@ def parse_api_version(value: str) -> Tuple[int, str]:
3941

4042
def get_env(key: str, default: Any = _undefined, *,
4143
clean: Callable[[str], Any] = lambda v: v):
42-
'''
44+
"""
4345
Retrieves a configuration value from the environment variables.
4446
The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then
4547
``"SORNA_"`` if the former does not exist.
@@ -52,7 +54,7 @@ def get_env(key: str, default: Any = _undefined, *,
5254
The default is returning the value as-is.
5355
5456
:returns: The value processed by the *clean* function.
55-
'''
57+
"""
5658
key = key.upper()
5759
v = os.environ.get('BACKEND_' + key)
5860
if v is None:
@@ -73,7 +75,7 @@ def bool_env(v: str) -> bool:
7375
raise ValueError('Unrecognized value of boolean environment variable', v)
7476

7577

76-
def _clean_urls(v: str) -> List[URL]:
78+
def _clean_urls(v: Union[URL, str]) -> List[URL]:
7779
if isinstance(v, URL):
7880
return [v]
7981
if isinstance(v, str):
@@ -95,7 +97,7 @@ def _clean_tokens(v):
9597

9698

9799
class APIConfig:
98-
'''
100+
"""
99101
Represents a set of API client configurations.
100102
The access key and secret key are mandatory -- they must be set in either
101103
environment variables or as the explicit arguments.
@@ -129,9 +131,9 @@ class APIConfig:
129131
access key) to be automatically mounted upon any
130132
:func:`Kernel.get_or_create()
131133
<ai.backend.client.kernel.Kernel.get_or_create>` calls.
132-
'''
134+
"""
133135

134-
DEFAULTS = {
136+
DEFAULTS: Mapping[str, Any] = {
135137
'endpoint': 'https://api.backend.ai',
136138
'endpoint_type': 'api',
137139
'version': f'v{API_VERSION[0]}.{API_VERSION[1]}',
@@ -141,9 +143,12 @@ class APIConfig:
141143
'connection_timeout': 10.0,
142144
'read_timeout': None,
143145
}
144-
'''
146+
"""
145147
The default values except the access and secret keys.
146-
'''
148+
"""
149+
150+
_group: str
151+
_hash_type: str
147152

148153
def __init__(self, *,
149154
endpoint: Union[URL, str] = None,
@@ -159,17 +164,17 @@ def __init__(self, *,
159164
skip_sslcert_validation: bool = None,
160165
connection_timeout: float = None,
161166
read_timeout: float = None) -> None:
162-
from . import get_user_agent # noqa; to avoid circular imports
167+
from . import get_user_agent
163168
self._endpoints = (
164169
_clean_urls(endpoint) if endpoint else
165170
get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls))
166171
random.shuffle(self._endpoints)
167-
self._endpoint_type = endpoint_type if endpoint_type \
172+
self._endpoint_type = endpoint_type if endpoint_type is not None \
168173
else get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type'])
169-
self._domain = domain if domain else get_env('DOMAIN', self.DEFAULTS['domain'])
170-
self._group = group if group else get_env('GROUP', self.DEFAULTS['group'])
171-
self._version = version if version else self.DEFAULTS['version']
172-
self._user_agent = user_agent if user_agent else get_user_agent()
174+
self._domain = domain if domain is not None else get_env('DOMAIN', self.DEFAULTS['domain'])
175+
self._group = group if group is not None else get_env('GROUP', self.DEFAULTS['group'])
176+
self._version = version if version is not None else self.DEFAULTS['version']
177+
self._user_agent = user_agent if user_agent is not None else get_user_agent()
173178
if self._endpoint_type == 'api':
174179
self._access_key = access_key if access_key is not None \
175180
else get_env('ACCESS_KEY', '')
@@ -178,8 +183,8 @@ def __init__(self, *,
178183
else:
179184
self._access_key = 'dummy'
180185
self._secret_key = 'dummy'
181-
self._hash_type = hash_type.lower() if hash_type else \
182-
self.DEFAULTS['hash_type']
186+
self._hash_type = hash_type.lower() if hash_type is not None else \
187+
cast(str, self.DEFAULTS['hash_type'])
183188
arg_vfolders = set(vfolder_mounts) if vfolder_mounts else set()
184189
env_vfolders = set(get_env('VFOLDER_MOUNTS', [], clean=_clean_tokens))
185190
self._vfolder_mounts = [*(arg_vfolders | env_vfolders)]
@@ -198,16 +203,16 @@ def is_anonymous(self) -> bool:
198203

199204
@property
200205
def endpoint(self) -> URL:
201-
'''
206+
"""
202207
The currently active endpoint URL.
203208
This may change if there are multiple configured endpoints
204209
and the current one is not accessible.
205-
'''
210+
"""
206211
return self._endpoints[0]
207212

208213
@property
209214
def endpoints(self) -> Sequence[URL]:
210-
'''All configured endpoint URLs.'''
215+
"""All configured endpoint URLs."""
211216
return self._endpoints
212217

213218
def rotate_endpoints(self):
@@ -217,83 +222,83 @@ def rotate_endpoints(self):
217222

218223
@property
219224
def endpoint_type(self) -> str:
220-
'''
225+
"""
221226
The configured endpoint type.
222-
'''
227+
"""
223228
return self._endpoint_type
224229

225230
@property
226231
def domain(self) -> str:
227-
'''The configured domain.'''
232+
"""The configured domain."""
228233
return self._domain
229234

230235
@property
231236
def group(self) -> str:
232-
'''The configured group.'''
237+
"""The configured group."""
233238
return self._group
234239

235240
@property
236241
def user_agent(self) -> str:
237-
'''The configured user agent string.'''
242+
"""The configured user agent string."""
238243
return self._user_agent
239244

240245
@property
241246
def access_key(self) -> str:
242-
'''The configured API access key.'''
247+
"""The configured API access key."""
243248
return self._access_key
244249

245250
@property
246251
def secret_key(self) -> str:
247-
'''The configured API secret key.'''
252+
"""The configured API secret key."""
248253
return self._secret_key
249254

250255
@property
251256
def version(self) -> str:
252-
'''The configured API protocol version.'''
257+
"""The configured API protocol version."""
253258
return self._version
254259

255260
@property
256261
def hash_type(self) -> str:
257-
'''The configured hash algorithm for API authentication signatures.'''
262+
"""The configured hash algorithm for API authentication signatures."""
258263
return self._hash_type
259264

260265
@property
261-
def vfolder_mounts(self) -> Tuple[str, ...]:
262-
'''The configured auto-mounted vfolder list.'''
266+
def vfolder_mounts(self) -> Sequence[str]:
267+
"""The configured auto-mounted vfolder list."""
263268
return self._vfolder_mounts
264269

265270
@property
266271
def skip_sslcert_validation(self) -> bool:
267-
'''Whether to skip SSL certificate validation for the API gateway.'''
272+
"""Whether to skip SSL certificate validation for the API gateway."""
268273
return self._skip_sslcert_validation
269274

270275
@property
271276
def connection_timeout(self) -> float:
272-
'''The maximum allowed duration for making TCP connections to the server.'''
277+
"""The maximum allowed duration for making TCP connections to the server."""
273278
return self._connection_timeout
274279

275280
@property
276281
def read_timeout(self) -> float:
277-
'''The maximum allowed waiting time for the first byte of the response from the server.'''
282+
"""The maximum allowed waiting time for the first byte of the response from the server."""
278283
return self._read_timeout
279284

280285

281286
def get_config():
282-
'''
287+
"""
283288
Returns the configuration for the current process.
284289
If there is no explicitly set :class:`APIConfig` instance,
285290
it will generate a new one from the current environment variables
286291
and defaults.
287-
'''
292+
"""
288293
global _config
289294
if _config is None:
290295
_config = APIConfig()
291296
return _config
292297

293298

294299
def set_config(conf: APIConfig):
295-
'''
300+
"""
296301
Sets the configuration used throughout the current process.
297-
'''
302+
"""
298303
global _config
299304
_config = conf

src/ai/backend/client/func/base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,19 @@ def _method(*args, **kwargs):
2828

2929

3030
def api_function(meth):
31-
'''
31+
"""
3232
Mark the wrapped method as the API function method.
33-
'''
33+
"""
3434
setattr(meth, '_backend_api', True)
3535
return meth
3636

3737

3838
class APIFunctionMeta(type):
39-
'''
39+
"""
4040
Converts all methods marked with :func:`api_function` into
4141
session-aware methods that are either plain Python functions
4242
or coroutines.
43-
'''
43+
"""
4444
_async = True
4545

4646
def __init__(cls, name, bases, attrs, **kwargs):

src/ai/backend/client/func/dotfile.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,16 @@ async def create(cls,
2929
}
3030
rqst.set_json(body)
3131
async with rqst.fetch() as resp:
32-
if resp.status == 200:
33-
await resp.json()
34-
return cls(path, owner_access_key=owner_access_key)
32+
await resp.json()
33+
return cls(path, owner_access_key=owner_access_key)
3534

3635
@api_function
3736
@classmethod
3837
async def list_dotfiles(cls) -> 'List[Mapping[str, str]]':
3938
rqst = Request(api_session.get(),
4039
'GET', '/user-config/dotfiles')
4140
async with rqst.fetch() as resp:
42-
if resp.status == 200:
43-
return await resp.json()
41+
return await resp.json()
4442

4543
def __init__(self, path: str, owner_access_key: str = None):
4644
self.path = path
@@ -55,8 +53,7 @@ async def get(self) -> str:
5553
'GET', f'/user-config/dotfiles',
5654
params=params)
5755
async with rqst.fetch() as resp:
58-
if resp.status == 200:
59-
return await resp.json()
56+
return await resp.json()
6057

6158
@api_function
6259
async def update(self, data: str, permission: str):

0 commit comments

Comments
 (0)