|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import contextlib
|
| 16 | +import logging |
| 17 | +import threading |
16 | 18 |
|
17 |
| -from sqlalchemy import Column, Integer, String, create_engine |
| 19 | +from sqlalchemy import Column, Integer, String, create_engine, insert |
18 | 20 | from sqlalchemy.ext.declarative import declarative_base
|
19 |
| -from sqlalchemy.orm import sessionmaker |
| 21 | +from sqlalchemy.orm import close_all_sessions, scoped_session, sessionmaker |
20 | 22 |
|
21 | 23 | from opentelemetry import trace
|
22 | 24 | from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
|
@@ -199,3 +201,45 @@ def test_parent(self):
|
199 | 201 | self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc")
|
200 | 202 |
|
201 | 203 | self.assertEqual(child_span.name, "SELECT " + self.SQL_DB)
|
| 204 | + |
| 205 | + def test_multithreading(self): |
| 206 | + """Ensure spans are captured correctly in a multithreading scenario |
| 207 | +
|
| 208 | + We also expect no logged warnings about calling end() on an ended span. |
| 209 | + """ |
| 210 | + |
| 211 | + if self.VENDOR == "sqlite": |
| 212 | + return |
| 213 | + |
| 214 | + def insert_player(session): |
| 215 | + _session = session() |
| 216 | + player = Player(name="Player") |
| 217 | + _session.add(player) |
| 218 | + _session.commit() |
| 219 | + _session.query(Player).all() |
| 220 | + |
| 221 | + def insert_players(session): |
| 222 | + _session = session() |
| 223 | + players = [] |
| 224 | + for player_number in range(3): |
| 225 | + players.append(Player(name=f"Player {player_number}")) |
| 226 | + _session.add_all(players) |
| 227 | + _session.commit() |
| 228 | + |
| 229 | + session_factory = sessionmaker(bind=self.engine) |
| 230 | + # pylint: disable=invalid-name |
| 231 | + Session = scoped_session(session_factory) |
| 232 | + thread_one = threading.Thread(target=insert_player, args=(Session,)) |
| 233 | + thread_two = threading.Thread(target=insert_players, args=(Session,)) |
| 234 | + |
| 235 | + logger = logging.getLogger("opentelemetry.sdk.trace") |
| 236 | + with self.assertRaises(AssertionError): |
| 237 | + with self.assertLogs(logger, level="WARNING"): |
| 238 | + thread_one.start() |
| 239 | + thread_two.start() |
| 240 | + thread_one.join() |
| 241 | + thread_two.join() |
| 242 | + close_all_sessions() |
| 243 | + |
| 244 | + spans = self.memory_exporter.get_finished_spans() |
| 245 | + self.assertEqual(len(spans), 5) |
0 commit comments