Skip to content

Commit a3ecbc1

Browse files
authored
Add SQLAlchemy multithreading test (#468)
1 parent 3d7cc64 commit a3ecbc1

File tree

1 file changed

+46
-2
lines changed
  • tests/opentelemetry-docker-tests/tests/sqlalchemy_tests

1 file changed

+46
-2
lines changed

tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
# limitations under the License.
1414

1515
import contextlib
16+
import logging
17+
import threading
1618

17-
from sqlalchemy import Column, Integer, String, create_engine
19+
from sqlalchemy import Column, Integer, String, create_engine, insert
1820
from sqlalchemy.ext.declarative import declarative_base
19-
from sqlalchemy.orm import sessionmaker
21+
from sqlalchemy.orm import close_all_sessions, scoped_session, sessionmaker
2022

2123
from opentelemetry import trace
2224
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
@@ -199,3 +201,45 @@ def test_parent(self):
199201
self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc")
200202

201203
self.assertEqual(child_span.name, "SELECT " + self.SQL_DB)
204+
205+
def test_multithreading(self):
206+
"""Ensure spans are captured correctly in a multithreading scenario
207+
208+
We also expect no logged warnings about calling end() on an ended span.
209+
"""
210+
211+
if self.VENDOR == "sqlite":
212+
return
213+
214+
def insert_player(session):
215+
_session = session()
216+
player = Player(name="Player")
217+
_session.add(player)
218+
_session.commit()
219+
_session.query(Player).all()
220+
221+
def insert_players(session):
222+
_session = session()
223+
players = []
224+
for player_number in range(3):
225+
players.append(Player(name=f"Player {player_number}"))
226+
_session.add_all(players)
227+
_session.commit()
228+
229+
session_factory = sessionmaker(bind=self.engine)
230+
# pylint: disable=invalid-name
231+
Session = scoped_session(session_factory)
232+
thread_one = threading.Thread(target=insert_player, args=(Session,))
233+
thread_two = threading.Thread(target=insert_players, args=(Session,))
234+
235+
logger = logging.getLogger("opentelemetry.sdk.trace")
236+
with self.assertRaises(AssertionError):
237+
with self.assertLogs(logger, level="WARNING"):
238+
thread_one.start()
239+
thread_two.start()
240+
thread_one.join()
241+
thread_two.join()
242+
close_all_sessions()
243+
244+
spans = self.memory_exporter.get_finished_spans()
245+
self.assertEqual(len(spans), 5)

0 commit comments

Comments
 (0)