@@ -114,18 +114,19 @@ class DisplayFixture:
114
114
"""A fixture for running web-based tests using ``playwright``"""
115
115
116
116
_exit_stack : AsyncExitStack
117
- _context : BrowserContext
118
117
119
118
def __init__ (
120
119
self ,
121
120
server : ServerFixture | None = None ,
122
- browser : Browser | None = None ,
121
+ driver : Browser | BrowserContext | Page | None = None ,
123
122
) -> None :
124
123
if server is not None :
125
124
self .server = server
126
- if browser is not None :
127
- self .browser = browser
128
-
125
+ if driver is not None :
126
+ if isinstance (driver , Page ):
127
+ self ._page = driver
128
+ else :
129
+ self ._browser = browser
129
130
self ._next_view_id = 0
130
131
131
132
async def show (
@@ -137,22 +138,21 @@ async def show(
137
138
view_id = f"display-{ self ._next_view_id } "
138
139
self .server .mount (lambda : html .div ({"id" : view_id }, component ()))
139
140
140
- page = await self ._context .new_page ()
141
-
142
- await page .goto (self .server .url (query = query ))
143
- await page .wait_for_selector (f"#{ view_id } " )
141
+ await self ._page .goto (self .server .url (query = query ))
142
+ await self ._page .wait_for_selector (f"#{ view_id } " )
144
143
145
- return page
144
+ return self . _page
146
145
147
146
async def __aenter__ (self : _Self ) -> _Self :
148
147
es = self ._exit_stack = AsyncExitStack ()
149
148
150
- if not hasattr (self , "browser" ):
151
- pw = await es .enter_async_context (async_playwright ())
152
- self .browser = await pw .chromium .launch ()
153
-
154
- self ._context = await self .browser .new_context ()
155
- es .push_async_callback (self ._context .close )
149
+ if not hasattr (self , "_page" ):
150
+ if not hasattr (self , "_browser" ):
151
+ pw = await es .enter_async_context (async_playwright ())
152
+ browser = await pw .chromium .launch ()
153
+ else :
154
+ browser = self ._browser
155
+ self ._page = await browser .new_page ()
156
156
157
157
if not hasattr (self , "server" ):
158
158
self .server = ServerFixture (** self ._server_options )
@@ -166,6 +166,7 @@ async def __aexit__(
166
166
exc_value : BaseException | None ,
167
167
traceback : TracebackType | None ,
168
168
) -> None :
169
+ self .server .mount (None )
169
170
await self ._exit_stack .aclose ()
170
171
171
172
0 commit comments