|
12 | 12 | import tarfile
|
13 | 13 | import time
|
14 | 14 | import zipfile
|
15 |
| -from typing import Any, AsyncIterator, Optional, Type |
| 15 | +from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Type |
16 | 16 | from unittest import mock
|
17 | 17 |
|
18 | 18 | import pytest
|
|
21 | 21 |
|
22 | 22 | import aiohttp
|
23 | 23 | from aiohttp import Fingerprint, ServerFingerprintMismatch, hdrs, web
|
24 |
| -from aiohttp.abc import AbstractResolver |
| 24 | +from aiohttp.abc import AbstractResolver, ResolveResult |
25 | 25 | from aiohttp.client_exceptions import (
|
26 | 26 | ClientResponseError,
|
27 | 27 | InvalidURL,
|
|
35 | 35 | from aiohttp.client_reqrep import ClientRequest
|
36 | 36 | from aiohttp.connector import Connection
|
37 | 37 | from aiohttp.http_writer import StreamWriter
|
38 |
| -from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer, TestClient |
39 |
| -from aiohttp.test_utils import unused_port |
| 38 | +from aiohttp.pytest_plugin import AiohttpClient, AiohttpServer |
| 39 | +from aiohttp.test_utils import TestClient, TestServer, unused_port |
| 40 | +from aiohttp.typedefs import Handler |
40 | 41 |
|
41 | 42 |
|
42 | 43 | @pytest.fixture
|
@@ -2888,6 +2889,68 @@ async def test_creds_in_auth_and_url() -> None:
|
2888 | 2889 | await session.close()
|
2889 | 2890 |
|
2890 | 2891 |
|
| 2892 | +async def test_creds_in_auth_and_redirect_url( |
| 2893 | + create_server_for_url_and_handler: Callable[[URL, Handler], Awaitable[TestServer]], |
| 2894 | +) -> None: |
| 2895 | + """Verify that credentials in redirect URLs can and do override any previous credentials.""" |
| 2896 | + url_from = URL("http://example.com") |
| 2897 | + url_to = URL( "http://[email protected]") |
| 2898 | + redirected = False |
| 2899 | + |
| 2900 | + async def srv(request: web.Request) -> web.Response: |
| 2901 | + nonlocal redirected |
| 2902 | + |
| 2903 | + assert request.host == url_from.host |
| 2904 | + |
| 2905 | + if not redirected: |
| 2906 | + redirected = True |
| 2907 | + raise web.HTTPMovedPermanently(url_to) |
| 2908 | + |
| 2909 | + return web.Response() |
| 2910 | + |
| 2911 | + server = await create_server_for_url_and_handler(url_from, srv) |
| 2912 | + |
| 2913 | + etc_hosts = { |
| 2914 | + (url_from.host, 80): server, |
| 2915 | + } |
| 2916 | + |
| 2917 | + class FakeResolver(AbstractResolver): |
| 2918 | + async def resolve( |
| 2919 | + self, |
| 2920 | + host: str, |
| 2921 | + port: int = 0, |
| 2922 | + family: socket.AddressFamily = socket.AF_INET, |
| 2923 | + ) -> List[ResolveResult]: |
| 2924 | + server = etc_hosts[(host, port)] |
| 2925 | + assert server.port is not None |
| 2926 | + |
| 2927 | + return [ |
| 2928 | + { |
| 2929 | + "hostname": host, |
| 2930 | + "host": server.host, |
| 2931 | + "port": server.port, |
| 2932 | + "family": socket.AF_INET, |
| 2933 | + "proto": 0, |
| 2934 | + "flags": socket.AI_NUMERICHOST, |
| 2935 | + } |
| 2936 | + ] |
| 2937 | + |
| 2938 | + async def close(self) -> None: |
| 2939 | + """Dummy""" |
| 2940 | + |
| 2941 | + connector = aiohttp.TCPConnector(resolver=FakeResolver(), ssl=False) |
| 2942 | + |
| 2943 | + async with aiohttp.ClientSession(connector=connector) as client, client.get( |
| 2944 | + url_from, auth=aiohttp.BasicAuth("user", "pass") |
| 2945 | + ) as resp: |
| 2946 | + assert len(resp.history) == 1 |
| 2947 | + assert str(resp.url) == "http://example.com" |
| 2948 | + assert resp.status == 200 |
| 2949 | + assert ( |
| 2950 | + resp.request_info.headers.get("authorization") == "Basic dXNlcjo=" |
| 2951 | + ), "Expected redirect credentials to take precedence over provided auth" |
| 2952 | + |
| 2953 | + |
2891 | 2954 | @pytest.fixture
|
2892 | 2955 | def create_server_for_url_and_handler(aiohttp_server, tls_certificate_authority):
|
2893 | 2956 | def create(url, srv):
|
|
0 commit comments