Skip to content

Commit dcfd8c9

Browse files
committed
httpx: rewrote patching to use wrapt instead of subclassing client
1 parent e4ece57 commit dcfd8c9

File tree

2 files changed

+214
-116
lines changed

2 files changed

+214
-116
lines changed

instrumentation/opentelemetry-instrumentation-httpx/src/opentelemetry/instrumentation/httpx/__init__.py

Lines changed: 196 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,10 @@ async def async_response_hook(span, request, response):
192192
"""
193193
import logging
194194
import typing
195-
from asyncio import iscoroutinefunction
196195
from types import TracebackType
197196

198197
import httpx
198+
from wrapt import wrap_function_wrapper
199199

200200
from opentelemetry.instrumentation._semconv import (
201201
_get_schema_url,
@@ -216,6 +216,7 @@ async def async_response_hook(span, request, response):
216216
from opentelemetry.instrumentation.utils import (
217217
http_status_to_status_code,
218218
is_http_instrumentation_enabled,
219+
unwrap,
219220
)
220221
from opentelemetry.propagate import inject
221222
from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE
@@ -728,44 +729,183 @@ def _instrument(self, **kwargs):
728729
``async_request_hook``: Async ``request_hook`` for ``httpx.AsyncClient``
729730
``async_response_hook``: Async``response_hook`` for ``httpx.AsyncClient``
730731
"""
731-
self._original_client = httpx.Client
732-
self._original_async_client = httpx.AsyncClient
733-
request_hook = kwargs.get("request_hook")
734-
response_hook = kwargs.get("response_hook")
735-
async_request_hook = kwargs.get("async_request_hook")
736-
async_response_hook = kwargs.get("async_response_hook")
737-
if callable(request_hook):
738-
_InstrumentedClient._request_hook = request_hook
739-
if callable(async_request_hook) and iscoroutinefunction(
740-
async_request_hook
741-
):
742-
_InstrumentedAsyncClient._request_hook = async_request_hook
743-
if callable(response_hook):
744-
_InstrumentedClient._response_hook = response_hook
745-
if callable(async_response_hook) and iscoroutinefunction(
746-
async_response_hook
747-
):
748-
_InstrumentedAsyncClient._response_hook = async_response_hook
749732
tracer_provider = kwargs.get("tracer_provider")
750-
_InstrumentedClient._tracer_provider = tracer_provider
751-
_InstrumentedAsyncClient._tracer_provider = tracer_provider
752-
# Intentionally using a private attribute here, see:
753-
# https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2538#discussion_r1610603719
754-
httpx.Client = httpx._api.Client = _InstrumentedClient
755-
httpx.AsyncClient = _InstrumentedAsyncClient
733+
self._request_hook = kwargs.get("request_hook")
734+
self._response_hook = kwargs.get("response_hook")
735+
self._async_request_hook = kwargs.get("async_request_hook")
736+
self._async_response_hook = kwargs.get("async_response_hook")
737+
738+
if getattr(self, "__instrumented", False):
739+
print("already instrumented")
740+
return
741+
742+
_OpenTelemetrySemanticConventionStability._initialize()
743+
self._sem_conv_opt_in_mode = _OpenTelemetrySemanticConventionStability._get_opentelemetry_stability_opt_in_mode(
744+
_OpenTelemetryStabilitySignalType.HTTP,
745+
)
746+
self._tracer = get_tracer(
747+
__name__,
748+
instrumenting_library_version=__version__,
749+
tracer_provider=tracer_provider,
750+
schema_url=_get_schema_url(self._sem_conv_opt_in_mode),
751+
)
752+
753+
wrap_function_wrapper(
754+
"httpx",
755+
"HTTPTransport.handle_request",
756+
self._handle_request_wrapper,
757+
)
758+
wrap_function_wrapper(
759+
"httpx",
760+
"AsyncHTTPTransport.handle_async_request",
761+
self._handle_async_request_wrapper,
762+
)
763+
764+
self.__instrumented = True
756765

757766
def _uninstrument(self, **kwargs):
758-
httpx.Client = httpx._api.Client = self._original_client
759-
httpx.AsyncClient = self._original_async_client
760-
_InstrumentedClient._tracer_provider = None
761-
_InstrumentedClient._request_hook = None
762-
_InstrumentedClient._response_hook = None
763-
_InstrumentedAsyncClient._tracer_provider = None
764-
_InstrumentedAsyncClient._request_hook = None
765-
_InstrumentedAsyncClient._response_hook = None
767+
import httpx
768+
769+
unwrap(httpx.HTTPTransport, "handle_request")
770+
unwrap(httpx.AsyncHTTPTransport, "handle_async_request")
771+
772+
def _handle_request_wrapper(self, wrapped, instance, args, kwargs):
773+
if not is_http_instrumentation_enabled():
774+
return wrapped(*args, **kwargs)
775+
776+
method, url, headers, stream, extensions = _extract_parameters(
777+
args, kwargs
778+
)
779+
method_original = method.decode()
780+
span_name = _get_default_span_name(method_original)
781+
span_attributes = {}
782+
# apply http client response attributes according to semconv
783+
_apply_request_client_attributes_to_span(
784+
span_attributes,
785+
url,
786+
method_original,
787+
self._sem_conv_opt_in_mode,
788+
)
789+
790+
request_info = RequestInfo(method, url, headers, stream, extensions)
791+
792+
with self._tracer.start_as_current_span(
793+
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
794+
) as span:
795+
exception = None
796+
if callable(self._request_hook):
797+
self._request_hook(span, request_info)
798+
799+
_inject_propagation_headers(headers, args, kwargs)
800+
801+
try:
802+
response = wrapped(*args, **kwargs)
803+
except Exception as exc: # pylint: disable=W0703
804+
exception = exc
805+
response = getattr(exc, "response", None)
806+
807+
if isinstance(response, (httpx.Response, tuple)):
808+
status_code, headers, stream, extensions, http_version = (
809+
_extract_response(response)
810+
)
811+
812+
if span.is_recording():
813+
# apply http client response attributes according to semconv
814+
_apply_response_client_attributes_to_span(
815+
span,
816+
status_code,
817+
http_version,
818+
self._sem_conv_opt_in_mode,
819+
)
820+
if callable(self._response_hook):
821+
self._response_hook(
822+
span,
823+
request_info,
824+
ResponseInfo(status_code, headers, stream, extensions),
825+
)
826+
827+
if exception:
828+
if span.is_recording() and _report_new(
829+
self._sem_conv_opt_in_mode
830+
):
831+
span.set_attribute(
832+
ERROR_TYPE, type(exception).__qualname__
833+
)
834+
raise exception.with_traceback(exception.__traceback__)
835+
836+
return response
837+
838+
async def _handle_async_request_wrapper(
839+
self, wrapped, instance, args, kwargs
840+
):
841+
if not is_http_instrumentation_enabled():
842+
return await wrapped(*args, **kwargs)
843+
844+
method, url, headers, stream, extensions = _extract_parameters(
845+
args, kwargs
846+
)
847+
method_original = method.decode()
848+
span_name = _get_default_span_name(method_original)
849+
span_attributes = {}
850+
# apply http client response attributes according to semconv
851+
_apply_request_client_attributes_to_span(
852+
span_attributes,
853+
url,
854+
method_original,
855+
self._sem_conv_opt_in_mode,
856+
)
857+
858+
request_info = RequestInfo(method, url, headers, stream, extensions)
859+
860+
with self._tracer.start_as_current_span(
861+
span_name, kind=SpanKind.CLIENT, attributes=span_attributes
862+
) as span:
863+
exception = None
864+
if callable(self._async_request_hook):
865+
await self._async_request_hook(span, request_info)
866+
867+
_inject_propagation_headers(headers, args, kwargs)
868+
869+
try:
870+
response = await wrapped(*args, **kwargs)
871+
except Exception as exc: # pylint: disable=W0703
872+
exception = exc
873+
response = getattr(exc, "response", None)
874+
875+
if isinstance(response, (httpx.Response, tuple)):
876+
status_code, headers, stream, extensions, http_version = (
877+
_extract_response(response)
878+
)
879+
880+
if span.is_recording():
881+
# apply http client response attributes according to semconv
882+
_apply_response_client_attributes_to_span(
883+
span,
884+
status_code,
885+
http_version,
886+
self._sem_conv_opt_in_mode,
887+
)
888+
889+
if callable(self._async_response_hook):
890+
await self._async_response_hook(
891+
span,
892+
request_info,
893+
ResponseInfo(status_code, headers, stream, extensions),
894+
)
895+
896+
if exception:
897+
if span.is_recording() and _report_new(
898+
self._sem_conv_opt_in_mode
899+
):
900+
span.set_attribute(
901+
ERROR_TYPE, type(exception).__qualname__
902+
)
903+
raise exception.with_traceback(exception.__traceback__)
904+
905+
return response
766906

767-
@staticmethod
768907
def instrument_client(
908+
self,
769909
client: typing.Union[httpx.Client, httpx.AsyncClient],
770910
tracer_provider: TracerProvider = None,
771911
request_hook: typing.Union[
@@ -785,67 +925,27 @@ def instrument_client(
785925
response_hook: A hook that receives the span, request, and response
786926
that is called right before the span ends
787927
"""
788-
# pylint: disable=protected-access
789-
if not hasattr(client, "_is_instrumented_by_opentelemetry"):
790-
client._is_instrumented_by_opentelemetry = False
791928

792-
if not client._is_instrumented_by_opentelemetry:
793-
if isinstance(client, httpx.Client):
794-
client._original_transport = client._transport
795-
client._original_mounts = client._mounts.copy()
796-
transport = client._transport or httpx.HTTPTransport()
797-
client._transport = SyncOpenTelemetryTransport(
798-
transport,
799-
tracer_provider=tracer_provider,
800-
request_hook=request_hook,
801-
response_hook=response_hook,
802-
)
803-
client._is_instrumented_by_opentelemetry = True
804-
client._mounts.update(
805-
{
806-
url_pattern: (
807-
SyncOpenTelemetryTransport(
808-
transport,
809-
tracer_provider=tracer_provider,
810-
request_hook=request_hook,
811-
response_hook=response_hook,
812-
)
813-
if transport is not None
814-
else transport
815-
)
816-
for url_pattern, transport in client._original_mounts.items()
817-
}
818-
)
819-
820-
if isinstance(client, httpx.AsyncClient):
821-
transport = client._transport or httpx.AsyncHTTPTransport()
822-
client._original_mounts = client._mounts.copy()
823-
client._transport = AsyncOpenTelemetryTransport(
824-
transport,
825-
tracer_provider=tracer_provider,
826-
request_hook=request_hook,
827-
response_hook=response_hook,
828-
)
829-
client._is_instrumented_by_opentelemetry = True
830-
client._mounts.update(
831-
{
832-
url_pattern: (
833-
AsyncOpenTelemetryTransport(
834-
transport,
835-
tracer_provider=tracer_provider,
836-
request_hook=request_hook,
837-
response_hook=response_hook,
838-
)
839-
if transport is not None
840-
else transport
841-
)
842-
for url_pattern, transport in client._original_mounts.items()
843-
}
844-
)
845-
else:
929+
if getattr(client, "_is_instrumented_by_opentelemetry", False):
846930
_logger.warning(
847931
"Attempting to instrument Httpx client while already instrumented"
848932
)
933+
return
934+
935+
if hasattr(client._transport, "handle_request"):
936+
wrap_function_wrapper(
937+
client._transport,
938+
"handle_request",
939+
self._handle_request_wrapper,
940+
)
941+
client._is_instrumented_by_opentelemetry = True
942+
if hasattr(client._transport, "handle_async_request"):
943+
wrap_function_wrapper(
944+
client._transport,
945+
"handle_async_request",
946+
self._handle_async_request_wrapper,
947+
)
948+
client._is_instrumented_by_opentelemetry = True
849949

850950
@staticmethod
851951
def uninstrument_client(
@@ -856,15 +956,9 @@ def uninstrument_client(
856956
Args:
857957
client: The httpx Client or AsyncClient instance
858958
"""
859-
if hasattr(client, "_original_transport"):
860-
client._transport = client._original_transport
861-
del client._original_transport
959+
if hasattr(client._transport, "handle_request"):
960+
unwrap(client._transport, "handle_request")
961+
client._is_instrumented_by_opentelemetry = False
962+
elif hasattr(client._transport, "handle_async_request"):
963+
unwrap(client._transport, "handle_async_request")
862964
client._is_instrumented_by_opentelemetry = False
863-
if hasattr(client, "_original_mounts"):
864-
client._mounts = client._original_mounts.copy()
865-
del client._original_mounts
866-
else:
867-
_logger.warning(
868-
"Attempting to uninstrument Httpx "
869-
"client while already uninstrumented"
870-
)

0 commit comments

Comments
 (0)