diff --git a/.release-please-manifest.json b/.release-please-manifest.json index c4bf1b6c04..0d3c59d336 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.13.3" + ".": "1.13.4" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 201757c90d..2c80f70cc6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,34 @@ # Changelog +## 1.13.4 (2024-03-13) + +Full Changelog: [v1.13.3...v1.13.4](https://github.com/openai/openai-python/compare/v1.13.3...v1.13.4) + +### Bug Fixes + +* **streaming:** improve error messages ([#1218](https://github.com/openai/openai-python/issues/1218)) ([4f5ff29](https://github.com/openai/openai-python/commit/4f5ff298601b5a8bfbf0a9d0c0d1329d1502a205)) + + +### Chores + +* **api:** update docs ([#1212](https://github.com/openai/openai-python/issues/1212)) ([71236e0](https://github.com/openai/openai-python/commit/71236e0de4012a249af4c1ffd95973a8ba4fa61f)) +* **client:** improve error message for invalid http_client argument ([#1216](https://github.com/openai/openai-python/issues/1216)) ([d0c928a](https://github.com/openai/openai-python/commit/d0c928abbd99020fe828350f3adfd10c638a2eed)) +* **docs:** mention install from git repo ([#1203](https://github.com/openai/openai-python/issues/1203)) ([3ab6f44](https://github.com/openai/openai-python/commit/3ab6f447ffd8d2394e58416e401e545a99ec85af)) +* export NOT_GIVEN sentinel value ([#1223](https://github.com/openai/openai-python/issues/1223)) ([8a4f76f](https://github.com/openai/openai-python/commit/8a4f76f992c66f20cd6aa070c8dc4839e4cf9f3c)) +* **internal:** add core support for deserializing into number response ([#1219](https://github.com/openai/openai-python/issues/1219)) ([004bc92](https://github.com/openai/openai-python/commit/004bc924ea579852b9266ca11aea93463cf75104)) +* **internal:** bump pyright ([#1221](https://github.com/openai/openai-python/issues/1221)) ([3c2e815](https://github.com/openai/openai-python/commit/3c2e815311ace4ff81ccd446b23ff50a4e099485)) +* **internal:** improve deserialisation of discriminated unions ([#1227](https://github.com/openai/openai-python/issues/1227)) ([4767259](https://github.com/openai/openai-python/commit/4767259d25ac135550b37b15e4c0497e5ff0330d)) +* **internal:** minor core client restructuring ([#1199](https://github.com/openai/openai-python/issues/1199)) ([4314cdc](https://github.com/openai/openai-python/commit/4314cdcd522537e6cbbd87206d5bb236f672ce05)) +* **internal:** split up transforms into sync / async ([#1210](https://github.com/openai/openai-python/issues/1210)) ([7853a83](https://github.com/openai/openai-python/commit/7853a8358864957cc183581bdf7c03810a7b2756)) +* **internal:** support more input types ([#1211](https://github.com/openai/openai-python/issues/1211)) ([d0e4baa](https://github.com/openai/openai-python/commit/d0e4baa40d32c2da0ce5ceef8e0c7193b98f2b5a)) +* **internal:** support parsing Annotated types ([#1222](https://github.com/openai/openai-python/issues/1222)) ([8598f81](https://github.com/openai/openai-python/commit/8598f81841eeab0ab00eb21fdec7e8756ffde909)) +* **types:** include discriminators in unions ([#1228](https://github.com/openai/openai-python/issues/1228)) ([3ba0dcc](https://github.com/openai/openai-python/commit/3ba0dcc19a2af0ef869c77da2805278f71ee96c2)) + + +### Documentation + +* **contributing:** improve wording ([#1201](https://github.com/openai/openai-python/issues/1201)) ([95a1e0e](https://github.com/openai/openai-python/commit/95a1e0ea8e5446c413606847ebf9e35afbc62bf9)) + ## 1.13.3 (2024-02-28) Full Changelog: [v1.13.2...v1.13.3](https://github.com/openai/openai-python/compare/v1.13.2...v1.13.3) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 914ab67053..7ab73dbf4c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -59,7 +59,7 @@ If you’d like to use the repository from source, you can either install from g To install via git: ```bash -pip install git+ssh://git@github.com:openai/openai-python.git +pip install git+ssh://git@github.com/openai/openai-python.git ``` Alternatively, you can build from source and install the wheel file: @@ -82,7 +82,7 @@ pip install ./path-to-wheel-file.whl ## Running tests -Most tests will require you to [setup a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. +Most tests require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. ```bash # you will need npm installed @@ -117,7 +117,7 @@ the changes aren't made through the automated pipeline, you may want to make rel ### Publish with a GitHub workflow -You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/openai/openai-python/actions/workflows/publish-pypi.yml). This will require a setup organization or repository secret to be set up. +You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/openai/openai-python/actions/workflows/publish-pypi.yml). This requires a setup organization or repository secret to be set up. ### Publish manually diff --git a/README.md b/README.md index 0e06cd5631..7d6c896d50 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ The REST API documentation can be found [on platform.openai.com](https://platfor > The SDK was rewritten in v1, which was released November 6th 2023. See the [v1 migration guide](https://github.com/openai/openai-python/discussions/742), which includes scripts to automatically update your code. ```sh +# install from PyPI pip install openai ``` diff --git a/pyproject.toml b/pyproject.toml index 171ede0aa4..9155a9aa22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "openai" -version = "1.13.3" +version = "1.13.4" description = "The official Python library for the openai API" readme = "README.md" license = "Apache-2.0" diff --git a/requirements-dev.lock b/requirements-dev.lock index fa95964d07..0392de573f 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -17,7 +17,7 @@ argcomplete==3.1.2 # via nox attrs==23.1.0 # via pytest -azure-core==1.30.0 +azure-core==1.30.1 # via azure-identity azure-identity==1.15.0 certifi==2023.7.22 @@ -96,7 +96,7 @@ pydantic-core==2.10.1 # via pydantic pyjwt==2.8.0 # via msal -pyright==1.1.351 +pyright==1.1.353 pytest==7.1.1 # via pytest-asyncio pytest-asyncio==0.21.1 diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 118fe8ee93..1037e3cdd5 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -6,7 +6,7 @@ from typing_extensions import override from . import types -from ._types import NoneType, Transport, ProxiesTypes +from ._types import NOT_GIVEN, NoneType, NotGiven, Transport, ProxiesTypes from ._utils import file_from_path from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions from ._models import BaseModel @@ -37,6 +37,8 @@ "NoneType", "Transport", "ProxiesTypes", + "NotGiven", + "NOT_GIVEN", "OpenAIError", "APIError", "APIStatusError", diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 73bd2411fd..f431128eef 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -79,7 +79,7 @@ RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER, ) -from ._streaming import Stream, AsyncStream +from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder from ._exceptions import ( APIStatusError, APITimeoutError, @@ -431,6 +431,9 @@ def _prepare_url(self, url: str) -> URL: return merge_url + def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder: + return SSEDecoder() + def _build_request( self, options: FinalRequestOptions, @@ -777,6 +780,11 @@ def __init__( else: timeout = DEFAULT_TIMEOUT + if http_client is not None and not isinstance(http_client, httpx.Client): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"Invalid `http_client` argument; Expected an instance of `httpx.Client` but got {type(http_client)}" + ) + super().__init__( version=version, limits=limits, @@ -1319,6 +1327,11 @@ def __init__( else: timeout = DEFAULT_TIMEOUT + if http_client is not None and not isinstance(http_client, httpx.AsyncClient): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"Invalid `http_client` argument; Expected an instance of `httpx.AsyncClient` but got {type(http_client)}" + ) + super().__init__( version=version, base_url=base_url, diff --git a/src/openai/_files.py b/src/openai/_files.py index bebfb19501..ad7b668b4b 100644 --- a/src/openai/_files.py +++ b/src/openai/_files.py @@ -13,12 +13,17 @@ FileContent, RequestFiles, HttpxFileTypes, + Base64FileInput, HttpxFileContent, HttpxRequestFiles, ) from ._utils import is_tuple_t, is_mapping_t, is_sequence_t +def is_base64_file_input(obj: object) -> TypeGuard[Base64FileInput]: + return isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) + + def is_file_content(obj: object) -> TypeGuard[FileContent]: return ( isinstance(obj, bytes) or isinstance(obj, tuple) or isinstance(obj, io.IOBase) or isinstance(obj, os.PathLike) diff --git a/src/openai/_legacy_response.py b/src/openai/_legacy_response.py index 6eaa691d9f..4585cd7423 100644 --- a/src/openai/_legacy_response.py +++ b/src/openai/_legacy_response.py @@ -13,7 +13,7 @@ import pydantic from ._types import NoneType -from ._utils import is_given +from ._utils import is_given, extract_type_arg, is_annotated_type from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type @@ -107,6 +107,8 @@ class MyModel(BaseModel): - `list` - `Union` - `str` + - `int` + - `float` - `httpx.Response` """ cache_key = to if to is not None else self._cast_to @@ -172,6 +174,10 @@ def elapsed(self) -> datetime.timedelta: return self.http_response.elapsed def _parse(self, *, to: type[_T] | None = None) -> R | _T: + # unwrap `Annotated[T, ...]` -> `T` + if to and is_annotated_type(to): + to = extract_type_arg(to, 0) + if self._stream: if to: if not is_stream_class_type(to): @@ -213,6 +219,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: ) cast_to = to if to is not None else self._cast_to + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(cast_to): + cast_to = extract_type_arg(cast_to, 0) + if cast_to is NoneType: return cast(R, None) @@ -220,6 +231,12 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: if cast_to == str: return cast(R, response.text) + if cast_to == int: + return cast(R, int(response.text)) + + if cast_to == float: + return cast(R, float(response.text)) + origin = get_origin(cast_to) or cast_to if inspect.isclass(origin) and issubclass(origin, HttpxBinaryResponseContent): @@ -307,7 +324,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIRespon @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "true" kwargs["extra_headers"] = extra_headers @@ -324,7 +341,7 @@ def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P @functools.wraps(func) async def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "true" kwargs["extra_headers"] = extra_headers diff --git a/src/openai/_models.py b/src/openai/_models.py index 810891497a..88afa40810 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -10,6 +10,7 @@ Protocol, Required, TypedDict, + TypeGuard, final, override, runtime_checkable, @@ -30,7 +31,18 @@ AnyMapping, HttpxRequestFiles, ) -from ._utils import is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given +from ._utils import ( + PropertyInfo, + is_list, + is_given, + is_mapping, + parse_date, + parse_datetime, + strip_not_given, + extract_type_arg, + is_annotated_type, + strip_annotated_type, +) from ._compat import ( PYDANTIC_V2, ConfigDict, @@ -46,6 +58,9 @@ ) from ._constants import RAW_RESPONSE_HEADER +if TYPE_CHECKING: + from pydantic_core.core_schema import ModelField, ModelFieldsSchema + __all__ = ["BaseModel", "GenericModel"] _T = TypeVar("_T") @@ -259,7 +274,6 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object: def is_basemodel(type_: type) -> bool: """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" - origin = get_origin(type_) or type_ if is_union(type_): for variant in get_args(type_): if is_basemodel(variant): @@ -267,6 +281,11 @@ def is_basemodel(type_: type) -> bool: return False + return is_basemodel_type(type_) + + +def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]: + origin = get_origin(type_) or type_ return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) @@ -275,6 +294,12 @@ def construct_type(*, value: object, type_: type) -> object: If the given value does not match the expected type then it is returned as-is. """ + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + meta = get_args(type_)[1:] + type_ = extract_type_arg(type_, 0) + else: + meta = tuple() # we need to use the origin class for any types that are subscripted generics # e.g. Dict[str, object] @@ -287,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object: except Exception: pass + # if the type is a discriminated union then we want to construct the right variant + # in the union, even if the data doesn't match exactly, otherwise we'd break code + # that relies on the constructed class types, e.g. + # + # class FooType: + # kind: Literal['foo'] + # value: str + # + # class BarType: + # kind: Literal['bar'] + # value: int + # + # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then + # we'd end up constructing `FooType` when it should be `BarType`. + discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta) + if discriminator and is_mapping(value): + variant_value = value.get(discriminator.field_alias_from or discriminator.field_name) + if variant_value and isinstance(variant_value, str): + variant_type = discriminator.mapping.get(variant_value) + if variant_type: + return construct_type(type_=variant_type, value=value) + # if the data is not valid, use the first variant that doesn't fail while deserializing for variant in args: try: @@ -344,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object: return value +@runtime_checkable +class CachedDiscriminatorType(Protocol): + __discriminator__: DiscriminatorDetails + + +class DiscriminatorDetails: + field_name: str + """The name of the discriminator field in the variant class, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] + ``` + + Will result in field_name='type' + """ + + field_alias_from: str | None + """The name of the discriminator field in the API response, e.g. + + ```py + class Foo(BaseModel): + type: Literal['foo'] = Field(alias='type_from_api') + ``` + + Will result in field_alias_from='type_from_api' + """ + + mapping: dict[str, type] + """Mapping of discriminator value to variant type, e.g. + + {'foo': FooVariant, 'bar': BarVariant} + """ + + def __init__( + self, + *, + mapping: dict[str, type], + discriminator_field: str, + discriminator_alias: str | None, + ) -> None: + self.mapping = mapping + self.field_name = discriminator_field + self.field_alias_from = discriminator_alias + + +def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: + if isinstance(union, CachedDiscriminatorType): + return union.__discriminator__ + + discriminator_field_name: str | None = None + + for annotation in meta_annotations: + if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None: + discriminator_field_name = annotation.discriminator + break + + if not discriminator_field_name: + return None + + mapping: dict[str, type] = {} + discriminator_alias: str | None = None + + for variant in get_args(union): + variant = strip_annotated_type(variant) + if is_basemodel_type(variant): + if PYDANTIC_V2: + field = _extract_field_schema_pv2(variant, discriminator_field_name) + if not field: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field.get("serialization_alias") + + field_schema = field["schema"] + + if field_schema["type"] == "literal": + for entry in field_schema["expected"]: + if isinstance(entry, str): + mapping[entry] = variant + else: + field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name) # pyright: ignore[reportDeprecated, reportUnnecessaryCast] + if not field_info: + continue + + # Note: if one variant defines an alias then they all should + discriminator_alias = field_info.alias + + if field_info.annotation and is_literal_type(field_info.annotation): + for entry in get_args(field_info.annotation): + if isinstance(entry, str): + mapping[entry] = variant + + if not mapping: + return None + + details = DiscriminatorDetails( + mapping=mapping, + discriminator_field=discriminator_field_name, + discriminator_alias=discriminator_alias, + ) + cast(CachedDiscriminatorType, union).__discriminator__ = details + return details + + +def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None: + schema = model.__pydantic_core_schema__ + if schema["type"] != "model": + return None + + fields_schema = schema["schema"] + if fields_schema["type"] != "model-fields": + return None + + fields_schema = cast("ModelFieldsSchema", fields_schema) + + field = fields_schema["fields"].get(field_name) + if not field: + return None + + return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast] + + def validate_type(*, type_: type[_T], value: object) -> _T: """Strict validation that the given value matches the expected type""" if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): diff --git a/src/openai/_response.py b/src/openai/_response.py index b1e070122f..47f484ef7a 100644 --- a/src/openai/_response.py +++ b/src/openai/_response.py @@ -25,7 +25,7 @@ import pydantic from ._types import NoneType -from ._utils import is_given, extract_type_var_from_base +from ._utils import is_given, extract_type_arg, is_annotated_type, extract_type_var_from_base from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type @@ -121,6 +121,10 @@ def __repr__(self) -> str: ) def _parse(self, *, to: type[_T] | None = None) -> R | _T: + # unwrap `Annotated[T, ...]` -> `T` + if to and is_annotated_type(to): + to = extract_type_arg(to, 0) + if self._is_sse_stream: if to: if not is_stream_class_type(to): @@ -162,6 +166,11 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: ) cast_to = to if to is not None else self._cast_to + + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(cast_to): + cast_to = extract_type_arg(cast_to, 0) + if cast_to is NoneType: return cast(R, None) @@ -172,6 +181,12 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T: if cast_to == bytes: return cast(R, response.content) + if cast_to == int: + return cast(R, int(response.text)) + + if cast_to == float: + return cast(R, float(response.text)) + origin = get_origin(cast_to) or cast_to # handle the legacy binary response case @@ -277,6 +292,8 @@ class MyModel(BaseModel): - `list` - `Union` - `str` + - `int` + - `float` - `httpx.Response` """ cache_key = to if to is not None else self._cast_to @@ -626,7 +643,7 @@ def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseCo @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" kwargs["extra_headers"] = extra_headers @@ -647,7 +664,7 @@ def async_to_streamed_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" kwargs["extra_headers"] = extra_headers @@ -671,7 +688,7 @@ def to_custom_streamed_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls @@ -696,7 +713,7 @@ def async_to_custom_streamed_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "stream" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls @@ -716,7 +733,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]] @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers @@ -733,7 +750,7 @@ def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P @functools.wraps(func) async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, str] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers @@ -755,7 +772,7 @@ def to_custom_raw_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls @@ -778,7 +795,7 @@ def async_to_custom_raw_response_wrapper( @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: - extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers: dict[str, Any] = {**(cast(Any, kwargs.get("extra_headers")) or {})} extra_headers[RAW_RESPONSE_HEADER] = "raw" extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index 74878fd0a0..41ed11074f 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -5,7 +5,7 @@ import inspect from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast -from typing_extensions import Self, TypeGuard, override, get_origin +from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable import httpx @@ -24,6 +24,8 @@ class Stream(Generic[_T]): response: httpx.Response + _decoder: SSEDecoder | SSEBytesDecoder + def __init__( self, *, @@ -34,7 +36,7 @@ def __init__( self.response = response self._cast_to = cast_to self._client = client - self._decoder = SSEDecoder() + self._decoder = client._make_sse_decoder() self._iterator = self.__stream__() def __next__(self) -> _T: @@ -45,7 +47,10 @@ def __iter__(self) -> Iterator[_T]: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: - yield from self._decoder.iter(self.response.iter_lines()) + if isinstance(self._decoder, SSEBytesDecoder): + yield from self._decoder.iter_bytes(self.response.iter_bytes()) + else: + yield from self._decoder.iter(self.response.iter_lines()) def __stream__(self) -> Iterator[_T]: cast_to = cast(Any, self._cast_to) @@ -60,8 +65,15 @@ def __stream__(self) -> Iterator[_T]: if sse.event is None: data = sse.json() if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + raise APIError( - message="An error occurred during streaming", + message=message, request=self.response.request, body=data["error"], ) @@ -97,6 +109,8 @@ class AsyncStream(Generic[_T]): response: httpx.Response + _decoder: SSEDecoder | SSEBytesDecoder + def __init__( self, *, @@ -107,7 +121,7 @@ def __init__( self.response = response self._cast_to = cast_to self._client = client - self._decoder = SSEDecoder() + self._decoder = client._make_sse_decoder() self._iterator = self.__stream__() async def __anext__(self) -> _T: @@ -118,8 +132,12 @@ async def __aiter__(self) -> AsyncIterator[_T]: yield item async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: - async for sse in self._decoder.aiter(self.response.aiter_lines()): - yield sse + if isinstance(self._decoder, SSEBytesDecoder): + async for sse in self._decoder.aiter_bytes(self.response.aiter_bytes()): + yield sse + else: + async for sse in self._decoder.aiter(self.response.aiter_lines()): + yield sse async def __stream__(self) -> AsyncIterator[_T]: cast_to = cast(Any, self._cast_to) @@ -134,8 +152,15 @@ async def __stream__(self) -> AsyncIterator[_T]: if sse.event is None: data = sse.json() if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + raise APIError( - message="An error occurred during streaming", + message=message, request=self.response.request, body=data["error"], ) @@ -284,6 +309,17 @@ def decode(self, line: str) -> ServerSentEvent | None: return None +@runtime_checkable +class SSEBytesDecoder(Protocol): + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]: + """Given an iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]: + """Given an async iterator that yields raw binary data, iterate over it & yield every event encountered""" + ... + + def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" origin = get_origin(typ) or typ diff --git a/src/openai/_types.py b/src/openai/_types.py index b5bf8f8af0..de9b1dd48b 100644 --- a/src/openai/_types.py +++ b/src/openai/_types.py @@ -41,8 +41,10 @@ ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] ProxiesTypes = Union[str, Proxy, ProxiesDict] if TYPE_CHECKING: + Base64FileInput = Union[IO[bytes], PathLike[str]] FileContent = Union[IO[bytes], bytes, PathLike[str]] else: + Base64FileInput = Union[IO[bytes], PathLike] FileContent = Union[IO[bytes], bytes, PathLike] # PathLike is not subscriptable in Python 3.8. FileTypes = Union[ # file (or bytes) diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index b5790a879f..5697894192 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -44,5 +44,7 @@ from ._transform import ( PropertyInfo as PropertyInfo, transform as transform, + async_transform as async_transform, maybe_transform as maybe_transform, + async_maybe_transform as async_maybe_transform, ) diff --git a/src/openai/_utils/_transform.py b/src/openai/_utils/_transform.py index 2cb7726c73..47e262a515 100644 --- a/src/openai/_utils/_transform.py +++ b/src/openai/_utils/_transform.py @@ -1,9 +1,13 @@ from __future__ import annotations +import io +import base64 +import pathlib from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime from typing_extensions import Literal, get_args, override, get_type_hints +import anyio import pydantic from ._utils import ( @@ -11,6 +15,7 @@ is_mapping, is_iterable, ) +from .._files import is_base64_file_input from ._typing import ( is_list_type, is_union_type, @@ -29,7 +34,7 @@ # TODO: ensure works correctly with forward references in all cases -PropertyFormat = Literal["iso8601", "custom"] +PropertyFormat = Literal["iso8601", "base64", "custom"] class PropertyInfo: @@ -46,6 +51,7 @@ class MyParams(TypedDict): alias: str | None format: PropertyFormat | None format_template: str | None + discriminator: str | None def __init__( self, @@ -53,14 +59,16 @@ def __init__( alias: str | None = None, format: PropertyFormat | None = None, format_template: str | None = None, + discriminator: str | None = None, ) -> None: self.alias = alias self.format = format self.format_template = format_template + self.discriminator = discriminator @override def __repr__(self) -> str: - return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}')" + return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')" def maybe_transform( @@ -180,11 +188,7 @@ def _transform_recursive( if isinstance(data, pydantic.BaseModel): return model_dump(data, exclude_unset=True) - return _transform_value(data, annotation) - - -def _transform_value(data: object, type_: type) -> object: - annotated_type = _get_annotated_type(type_) + annotated_type = _get_annotated_type(annotation) if annotated_type is None: return data @@ -205,6 +209,22 @@ def _format_data(data: object, format_: PropertyFormat, format_template: str | N if format_ == "custom" and format_template is not None: return data.strftime(format_template) + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = data.read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + return data @@ -222,3 +242,141 @@ def _transform_typeddict( else: result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_) return result + + +async def async_maybe_transform( + data: object, + expected_type: object, +) -> Any | None: + """Wrapper over `async_transform()` that allows `None` to be passed. + + See `async_transform()` for more details. + """ + if data is None: + return None + return await async_transform(data, expected_type) + + +async def async_transform( + data: _T, + expected_type: object, +) -> _T: + """Transform dictionaries based off of type information from the given type, for example: + + ```py + class Params(TypedDict, total=False): + card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]] + + + transformed = transform({"card_id": ""}, Params) + # {'cardID': ''} + ``` + + Any keys / data that does not have type information given will be included as is. + + It should be noted that the transformations that this function does are not represented in the type system. + """ + transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type)) + return cast(_T, transformed) + + +async def _async_transform_recursive( + data: object, + *, + annotation: type, + inner_type: type | None = None, +) -> object: + """Transform the given data against the expected type. + + Args: + annotation: The direct type annotation given to the particular piece of data. + This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc + + inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type + is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in + the list can be transformed using the metadata from the container type. + + Defaults to the same value as the `annotation` argument. + """ + if inner_type is None: + inner_type = annotation + + stripped_type = strip_annotated_type(inner_type) + if is_typeddict(stripped_type) and is_mapping(data): + return await _async_transform_typeddict(data, stripped_type) + + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): + inner_type = extract_type_arg(stripped_type, 0) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] + + if is_union_type(stripped_type): + # For union types we run the transformation against all subtypes to ensure that everything is transformed. + # + # TODO: there may be edge cases where the same normalized field name will transform to two different names + # in different subtypes. + for subtype in get_args(stripped_type): + data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype) + return data + + if isinstance(data, pydantic.BaseModel): + return model_dump(data, exclude_unset=True) + + annotated_type = _get_annotated_type(annotation) + if annotated_type is None: + return data + + # ignore the first argument as it is the actual type + annotations = get_args(annotated_type)[1:] + for annotation in annotations: + if isinstance(annotation, PropertyInfo) and annotation.format is not None: + return await _async_format_data(data, annotation.format, annotation.format_template) + + return data + + +async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object: + if isinstance(data, (date, datetime)): + if format_ == "iso8601": + return data.isoformat() + + if format_ == "custom" and format_template is not None: + return data.strftime(format_template) + + if format_ == "base64" and is_base64_file_input(data): + binary: str | bytes | None = None + + if isinstance(data, pathlib.Path): + binary = await anyio.Path(data).read_bytes() + elif isinstance(data, io.IOBase): + binary = data.read() + + if isinstance(binary, str): # type: ignore[unreachable] + binary = binary.encode() + + if not isinstance(binary, bytes): + raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}") + + return base64.b64encode(binary).decode("ascii") + + return data + + +async def _async_transform_typeddict( + data: Mapping[str, object], + expected_type: type, +) -> Mapping[str, object]: + result: dict[str, object] = {} + annotations = get_type_hints(expected_type, include_extras=True) + for key, value in data.items(): + type_ = annotations.get(key) + if type_ is None: + # we do not have a type annotation for this field, leave it as is + result[key] = value + else: + result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) + return result diff --git a/src/openai/_version.py b/src/openai/_version.py index 503a06141f..4c59f5e629 100644 --- a/src/openai/_version.py +++ b/src/openai/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. __title__ = "openai" -__version__ = "1.13.3" # x-release-please-version +__version__ = "1.13.4" # x-release-please-version diff --git a/src/openai/resources/audio/speech.py b/src/openai/resources/audio/speech.py index a569751ee5..bf4a0245f6 100644 --- a/src/openai/resources/audio/speech.py +++ b/src/openai/resources/audio/speech.py @@ -9,7 +9,10 @@ from ... import _legacy_response from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import maybe_transform +from ..._utils import ( + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -41,7 +44,7 @@ def create( input: str, model: Union[str, Literal["tts-1", "tts-1-hd"]], voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], - response_format: Literal["mp3", "opus", "aac", "flac", "pcm", "wav"] | NotGiven = NOT_GIVEN, + response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | NotGiven = NOT_GIVEN, speed: float | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -65,11 +68,8 @@ def create( available in the [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options). - response_format: The format to return audio in. Supported formats are `mp3`, `opus`, `aac`, - `flac`, `pcm`, and `wav`. - - The `pcm` audio format, similar to `wav` but without a header, utilizes a 24kHz - sample rate, mono channel, and 16-bit depth in signed little-endian format. + response_format: The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, + `wav`, and `pcm`. speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default. @@ -117,7 +117,7 @@ async def create( input: str, model: Union[str, Literal["tts-1", "tts-1-hd"]], voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], - response_format: Literal["mp3", "opus", "aac", "flac", "pcm", "wav"] | NotGiven = NOT_GIVEN, + response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] | NotGiven = NOT_GIVEN, speed: float | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -141,11 +141,8 @@ async def create( available in the [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options). - response_format: The format to return audio in. Supported formats are `mp3`, `opus`, `aac`, - `flac`, `pcm`, and `wav`. - - The `pcm` audio format, similar to `wav` but without a header, utilizes a 24kHz - sample rate, mono channel, and 16-bit depth in signed little-endian format. + response_format: The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, + `wav`, and `pcm`. speed: The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default. @@ -161,7 +158,7 @@ async def create( extra_headers = {"Accept": "application/octet-stream", **(extra_headers or {})} return await self._post( "/audio/speech", - body=maybe_transform( + body=await async_maybe_transform( { "input": input, "model": model, diff --git a/src/openai/resources/audio/transcriptions.py b/src/openai/resources/audio/transcriptions.py index 275098ce88..cfd9aae909 100644 --- a/src/openai/resources/audio/transcriptions.py +++ b/src/openai/resources/audio/transcriptions.py @@ -9,7 +9,12 @@ from ... import _legacy_response from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -55,7 +60,8 @@ def create( The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - model: ID of the model to use. Only `whisper-1` is currently available. + model: ID of the model to use. Only `whisper-1` (which is powered by our open source + Whisper V2 model) is currently available. language: The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will @@ -75,9 +81,11 @@ def create( [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. - timestamp_granularities: The timestamp granularities to populate for this transcription. Any of these - options: `word`, or `segment`. Note: There is no additional latency for segment - timestamps, but generating word timestamps incurs additional latency. + timestamp_granularities: The timestamp granularities to populate for this transcription. + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: There + is no additional latency for segment timestamps, but generating word timestamps + incurs additional latency. extra_headers: Send extra headers @@ -149,7 +157,8 @@ async def create( The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - model: ID of the model to use. Only `whisper-1` is currently available. + model: ID of the model to use. Only `whisper-1` (which is powered by our open source + Whisper V2 model) is currently available. language: The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will @@ -169,9 +178,11 @@ async def create( [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. - timestamp_granularities: The timestamp granularities to populate for this transcription. Any of these - options: `word`, or `segment`. Note: There is no additional latency for segment - timestamps, but generating word timestamps incurs additional latency. + timestamp_granularities: The timestamp granularities to populate for this transcription. + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: There + is no additional latency for segment timestamps, but generating word timestamps + incurs additional latency. extra_headers: Send extra headers @@ -200,7 +211,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/audio/transcriptions", - body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams), + body=await async_maybe_transform(body, transcription_create_params.TranscriptionCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/openai/resources/audio/translations.py b/src/openai/resources/audio/translations.py index d6cbc75886..6063522237 100644 --- a/src/openai/resources/audio/translations.py +++ b/src/openai/resources/audio/translations.py @@ -9,7 +9,12 @@ from ... import _legacy_response from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from ..._utils import extract_files, maybe_transform, deepcopy_minimal +from ..._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -52,7 +57,8 @@ def create( file: The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - model: ID of the model to use. Only `whisper-1` is currently available. + model: ID of the model to use. Only `whisper-1` (which is powered by our open source + Whisper V2 model) is currently available. prompt: An optional text to guide the model's style or continue a previous audio segment. The @@ -133,7 +139,8 @@ async def create( file: The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - model: ID of the model to use. Only `whisper-1` is currently available. + model: ID of the model to use. Only `whisper-1` (which is powered by our open source + Whisper V2 model) is currently available. prompt: An optional text to guide the model's style or continue a previous audio segment. The @@ -174,7 +181,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/audio/translations", - body=maybe_transform(body, translation_create_params.TranslationCreateParams), + body=await async_maybe_transform(body, translation_create_params.TranslationCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/openai/resources/beta/assistants/assistants.py b/src/openai/resources/beta/assistants/assistants.py index e926c31642..3aef33c95e 100644 --- a/src/openai/resources/beta/assistants/assistants.py +++ b/src/openai/resources/beta/assistants/assistants.py @@ -17,7 +17,10 @@ AsyncFilesWithStreamingResponse, ) from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ...._utils import maybe_transform +from ...._utils import ( + maybe_transform, + async_maybe_transform, +) from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -410,7 +413,7 @@ async def create( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( "/assistants", - body=maybe_transform( + body=await async_maybe_transform( { "model": model, "description": description, @@ -525,7 +528,7 @@ async def update( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/assistants/{assistant_id}", - body=maybe_transform( + body=await async_maybe_transform( { "description": description, "file_ids": file_ids, diff --git a/src/openai/resources/beta/assistants/files.py b/src/openai/resources/beta/assistants/files.py index c21465036a..8d5657666c 100644 --- a/src/openai/resources/beta/assistants/files.py +++ b/src/openai/resources/beta/assistants/files.py @@ -8,7 +8,10 @@ from .... import _legacy_response from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ...._utils import maybe_transform +from ...._utils import ( + maybe_transform, + async_maybe_transform, +) from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -259,7 +262,7 @@ async def create( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/assistants/{assistant_id}/files", - body=maybe_transform({"file_id": file_id}, file_create_params.FileCreateParams), + body=await async_maybe_transform({"file_id": file_id}, file_create_params.FileCreateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/openai/resources/beta/threads/messages/messages.py b/src/openai/resources/beta/threads/messages/messages.py index c95cdd5d00..2c0994d1f2 100644 --- a/src/openai/resources/beta/threads/messages/messages.py +++ b/src/openai/resources/beta/threads/messages/messages.py @@ -17,7 +17,10 @@ AsyncFilesWithStreamingResponse, ) from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ....._utils import maybe_transform +from ....._utils import ( + maybe_transform, + async_maybe_transform, +) from ....._compat import cached_property from ....._resource import SyncAPIResource, AsyncAPIResource from ....._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -315,7 +318,7 @@ async def create( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/threads/{thread_id}/messages", - body=maybe_transform( + body=await async_maybe_transform( { "content": content, "role": role, @@ -404,7 +407,7 @@ async def update( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/threads/{thread_id}/messages/{message_id}", - body=maybe_transform({"metadata": metadata}, message_update_params.MessageUpdateParams), + body=await async_maybe_transform({"metadata": metadata}, message_update_params.MessageUpdateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), diff --git a/src/openai/resources/beta/threads/runs/runs.py b/src/openai/resources/beta/threads/runs/runs.py index 9b18336010..62cfa6b742 100644 --- a/src/openai/resources/beta/threads/runs/runs.py +++ b/src/openai/resources/beta/threads/runs/runs.py @@ -17,7 +17,10 @@ AsyncStepsWithStreamingResponse, ) from ....._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ....._utils import maybe_transform +from ....._utils import ( + maybe_transform, + async_maybe_transform, +) from ....._compat import cached_property from ....._resource import SyncAPIResource, AsyncAPIResource from ....._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -430,7 +433,7 @@ async def create( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/threads/{thread_id}/runs", - body=maybe_transform( + body=await async_maybe_transform( { "assistant_id": assistant_id, "additional_instructions": additional_instructions, @@ -521,7 +524,7 @@ async def update( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/threads/{thread_id}/runs/{run_id}", - body=maybe_transform({"metadata": metadata}, run_update_params.RunUpdateParams), + body=await async_maybe_transform({"metadata": metadata}, run_update_params.RunUpdateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -669,7 +672,7 @@ async def submit_tool_outputs( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/threads/{thread_id}/runs/{run_id}/submit_tool_outputs", - body=maybe_transform( + body=await async_maybe_transform( {"tool_outputs": tool_outputs}, run_submit_tool_outputs_params.RunSubmitToolOutputsParams ), options=make_request_options( diff --git a/src/openai/resources/beta/threads/threads.py b/src/openai/resources/beta/threads/threads.py index dd079ac533..cc0e1c0959 100644 --- a/src/openai/resources/beta/threads/threads.py +++ b/src/openai/resources/beta/threads/threads.py @@ -24,7 +24,10 @@ AsyncMessagesWithStreamingResponse, ) from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ...._utils import maybe_transform +from ...._utils import ( + maybe_transform, + async_maybe_transform, +) from .runs.runs import Runs, AsyncRuns from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource @@ -342,7 +345,7 @@ async def create( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( "/threads", - body=maybe_transform( + body=await async_maybe_transform( { "messages": messages, "metadata": metadata, @@ -423,7 +426,7 @@ async def update( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( f"/threads/{thread_id}", - body=maybe_transform({"metadata": metadata}, thread_update_params.ThreadUpdateParams), + body=await async_maybe_transform({"metadata": metadata}, thread_update_params.ThreadUpdateParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -517,7 +520,7 @@ async def create_and_run( extra_headers = {"OpenAI-Beta": "assistants=v1", **(extra_headers or {})} return await self._post( "/threads/runs", - body=maybe_transform( + body=await async_maybe_transform( { "assistant_id": assistant_id, "instructions": instructions, diff --git a/src/openai/resources/chat/completions.py b/src/openai/resources/chat/completions.py index 0011d75e6e..20ea4cffbb 100644 --- a/src/openai/resources/chat/completions.py +++ b/src/openai/resources/chat/completions.py @@ -9,7 +9,11 @@ from ... import _legacy_response from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import required_args, maybe_transform +from ..._utils import ( + required_args, + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -204,9 +208,9 @@ def create( tool. Use this to provide a list of functions the model may generate JSON inputs for. - top_logprobs: An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 @@ -394,9 +398,9 @@ def create( tool. Use this to provide a list of functions the model may generate JSON inputs for. - top_logprobs: An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 @@ -584,9 +588,9 @@ def create( tool. Use this to provide a list of functions the model may generate JSON inputs for. - top_logprobs: An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 @@ -871,9 +875,9 @@ async def create( tool. Use this to provide a list of functions the model may generate JSON inputs for. - top_logprobs: An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 @@ -1061,9 +1065,9 @@ async def create( tool. Use this to provide a list of functions the model may generate JSON inputs for. - top_logprobs: An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 @@ -1251,9 +1255,9 @@ async def create( tool. Use this to provide a list of functions the model may generate JSON inputs for. - top_logprobs: An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 @@ -1329,7 +1333,7 @@ async def create( ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: return await self._post( "/chat/completions", - body=maybe_transform( + body=await async_maybe_transform( { "messages": messages, "model": model, diff --git a/src/openai/resources/completions.py b/src/openai/resources/completions.py index af2d6e2e51..6d3756f6ba 100644 --- a/src/openai/resources/completions.py +++ b/src/openai/resources/completions.py @@ -10,7 +10,11 @@ from .. import _legacy_response from ..types import Completion, completion_create_params from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from .._utils import required_args, maybe_transform +from .._utils import ( + required_args, + maybe_transform, + async_maybe_transform, +) from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -1019,7 +1023,7 @@ async def create( ) -> Completion | AsyncStream[Completion]: return await self._post( "/completions", - body=maybe_transform( + body=await async_maybe_transform( { "model": model, "prompt": prompt, diff --git a/src/openai/resources/files.py b/src/openai/resources/files.py index 8b2bc4f181..3ea66656b3 100644 --- a/src/openai/resources/files.py +++ b/src/openai/resources/files.py @@ -12,7 +12,12 @@ from .. import _legacy_response from ..types import FileObject, FileDeleted, file_list_params, file_create_params from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from .._utils import extract_files, maybe_transform, deepcopy_minimal +from .._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -374,7 +379,7 @@ async def create( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/files", - body=maybe_transform(body, file_create_params.FileCreateParams), + body=await async_maybe_transform(body, file_create_params.FileCreateParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout diff --git a/src/openai/resources/fine_tuning/jobs.py b/src/openai/resources/fine_tuning/jobs.py index 6b59932982..8338de12c4 100644 --- a/src/openai/resources/fine_tuning/jobs.py +++ b/src/openai/resources/fine_tuning/jobs.py @@ -9,7 +9,10 @@ from ... import _legacy_response from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import maybe_transform +from ..._utils import ( + maybe_transform, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -369,7 +372,7 @@ async def create( """ return await self._post( "/fine_tuning/jobs", - body=maybe_transform( + body=await async_maybe_transform( { "model": model, "training_file": training_file, diff --git a/src/openai/resources/images.py b/src/openai/resources/images.py index 91530e47ca..f5bbdbc338 100644 --- a/src/openai/resources/images.py +++ b/src/openai/resources/images.py @@ -15,7 +15,12 @@ image_create_variation_params, ) from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes -from .._utils import extract_files, maybe_transform, deepcopy_minimal +from .._utils import ( + extract_files, + maybe_transform, + deepcopy_minimal, + async_maybe_transform, +) from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -65,7 +70,8 @@ def create_variation( `n=1` is supported. response_format: The format in which the generated images are returned. Must be one of `url` or - `b64_json`. + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. size: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. @@ -146,7 +152,8 @@ def edit( n: The number of images to generate. Must be between 1 and 10. response_format: The format in which the generated images are returned. Must be one of `url` or - `b64_json`. + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. size: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. @@ -226,7 +233,8 @@ def generate( for `dall-e-3`. response_format: The format in which the generated images are returned. Must be one of `url` or - `b64_json`. + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. size: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or @@ -310,7 +318,8 @@ async def create_variation( `n=1` is supported. response_format: The format in which the generated images are returned. Must be one of `url` or - `b64_json`. + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. size: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. @@ -345,7 +354,7 @@ async def create_variation( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/images/variations", - body=maybe_transform(body, image_create_variation_params.ImageCreateVariationParams), + body=await async_maybe_transform(body, image_create_variation_params.ImageCreateVariationParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -391,7 +400,8 @@ async def edit( n: The number of images to generate. Must be between 1 and 10. response_format: The format in which the generated images are returned. Must be one of `url` or - `b64_json`. + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. size: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. @@ -428,7 +438,7 @@ async def edit( extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} return await self._post( "/images/edits", - body=maybe_transform(body, image_edit_params.ImageEditParams), + body=await async_maybe_transform(body, image_edit_params.ImageEditParams), files=files, options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -471,7 +481,8 @@ async def generate( for `dall-e-3`. response_format: The format in which the generated images are returned. Must be one of `url` or - `b64_json`. + `b64_json`. URLs are only valid for 60 minutes after the image has been + generated. size: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or @@ -496,7 +507,7 @@ async def generate( """ return await self._post( "/images/generations", - body=maybe_transform( + body=await async_maybe_transform( { "prompt": prompt, "model": model, diff --git a/src/openai/resources/moderations.py b/src/openai/resources/moderations.py index 540d089071..ac5ca1b64b 100644 --- a/src/openai/resources/moderations.py +++ b/src/openai/resources/moderations.py @@ -10,7 +10,10 @@ from .. import _legacy_response from ..types import ModerationCreateResponse, moderation_create_params from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from .._utils import maybe_transform +from .._utils import ( + maybe_transform, + async_maybe_transform, +) from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper @@ -43,7 +46,7 @@ def create( timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ModerationCreateResponse: """ - Classifies if text violates OpenAI's Content Policy + Classifies if text is potentially harmful. Args: input: The input text to classify @@ -103,7 +106,7 @@ async def create( timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> ModerationCreateResponse: """ - Classifies if text violates OpenAI's Content Policy + Classifies if text is potentially harmful. Args: input: The input text to classify @@ -127,7 +130,7 @@ async def create( """ return await self._post( "/moderations", - body=maybe_transform( + body=await async_maybe_transform( { "input": input, "model": model, diff --git a/src/openai/types/audio/speech_create_params.py b/src/openai/types/audio/speech_create_params.py index 00f862272e..0078a9d03a 100644 --- a/src/openai/types/audio/speech_create_params.py +++ b/src/openai/types/audio/speech_create_params.py @@ -26,13 +26,10 @@ class SpeechCreateParams(TypedDict, total=False): [Text to speech guide](https://platform.openai.com/docs/guides/text-to-speech/voice-options). """ - response_format: Literal["mp3", "opus", "aac", "flac", "pcm", "wav"] - """The format to return audio in. + response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] + """The format to audio in. - Supported formats are `mp3`, `opus`, `aac`, `flac`, `pcm`, and `wav`. - - The `pcm` audio format, similar to `wav` but without a header, utilizes a 24kHz - sample rate, mono channel, and 16-bit depth in signed little-endian format. + Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`. """ speed: float diff --git a/src/openai/types/audio/transcription.py b/src/openai/types/audio/transcription.py index d2274faa0e..6532611731 100644 --- a/src/openai/types/audio/transcription.py +++ b/src/openai/types/audio/transcription.py @@ -7,3 +7,4 @@ class Transcription(BaseModel): text: str + """The transcribed text.""" diff --git a/src/openai/types/audio/transcription_create_params.py b/src/openai/types/audio/transcription_create_params.py index 5a90822144..4164a594cc 100644 --- a/src/openai/types/audio/transcription_create_params.py +++ b/src/openai/types/audio/transcription_create_params.py @@ -18,7 +18,11 @@ class TranscriptionCreateParams(TypedDict, total=False): """ model: Required[Union[str, Literal["whisper-1"]]] - """ID of the model to use. Only `whisper-1` is currently available.""" + """ID of the model to use. + + Only `whisper-1` (which is powered by our open source Whisper V2 model) is + currently available. + """ language: str """The language of the input audio. @@ -54,7 +58,8 @@ class TranscriptionCreateParams(TypedDict, total=False): timestamp_granularities: List[Literal["word", "segment"]] """The timestamp granularities to populate for this transcription. - Any of these options: `word`, or `segment`. Note: There is no additional latency - for segment timestamps, but generating word timestamps incurs additional - latency. + `response_format` must be set `verbose_json` to use timestamp granularities. + Either or both of these options are supported: `word`, or `segment`. Note: There + is no additional latency for segment timestamps, but generating word timestamps + incurs additional latency. """ diff --git a/src/openai/types/audio/translation_create_params.py b/src/openai/types/audio/translation_create_params.py index d3cb4b9e63..1ae312da49 100644 --- a/src/openai/types/audio/translation_create_params.py +++ b/src/openai/types/audio/translation_create_params.py @@ -18,7 +18,11 @@ class TranslationCreateParams(TypedDict, total=False): """ model: Required[Union[str, Literal["whisper-1"]]] - """ID of the model to use. Only `whisper-1` is currently available.""" + """ID of the model to use. + + Only `whisper-1` (which is powered by our open source Whisper V2 model) is + currently available. + """ prompt: str """An optional text to guide the model's style or continue a previous audio diff --git a/src/openai/types/beta/assistant.py b/src/openai/types/beta/assistant.py index 7a29984b50..7ba50652aa 100644 --- a/src/openai/types/beta/assistant.py +++ b/src/openai/types/beta/assistant.py @@ -1,9 +1,10 @@ # File generated from our OpenAPI spec by Stainless. from typing import List, Union, Optional -from typing_extensions import Literal +from typing_extensions import Literal, Annotated from ..shared import FunctionDefinition +from ..._utils import PropertyInfo from ..._models import BaseModel __all__ = ["Assistant", "Tool", "ToolCodeInterpreter", "ToolRetrieval", "ToolFunction"] @@ -26,7 +27,7 @@ class ToolFunction(BaseModel): """The type of tool being defined: `function`""" -Tool = Union[ToolCodeInterpreter, ToolRetrieval, ToolFunction] +Tool = Annotated[Union[ToolCodeInterpreter, ToolRetrieval, ToolFunction], PropertyInfo(discriminator="type")] class Assistant(BaseModel): diff --git a/src/openai/types/beta/threads/message_content_text.py b/src/openai/types/beta/threads/message_content_text.py index b529a384c6..dd05ff96ca 100644 --- a/src/openai/types/beta/threads/message_content_text.py +++ b/src/openai/types/beta/threads/message_content_text.py @@ -1,8 +1,9 @@ # File generated from our OpenAPI spec by Stainless. from typing import List, Union -from typing_extensions import Literal +from typing_extensions import Literal, Annotated +from ...._utils import PropertyInfo from ...._models import BaseModel __all__ = [ @@ -57,7 +58,9 @@ class TextAnnotationFilePath(BaseModel): """Always `file_path`.""" -TextAnnotation = Union[TextAnnotationFileCitation, TextAnnotationFilePath] +TextAnnotation = Annotated[ + Union[TextAnnotationFileCitation, TextAnnotationFilePath], PropertyInfo(discriminator="type") +] class Text(BaseModel): diff --git a/src/openai/types/beta/threads/run.py b/src/openai/types/beta/threads/run.py index 79e4f6a444..38625d3781 100644 --- a/src/openai/types/beta/threads/run.py +++ b/src/openai/types/beta/threads/run.py @@ -22,8 +22,8 @@ class LastError(BaseModel): - code: Literal["server_error", "rate_limit_exceeded"] - """One of `server_error` or `rate_limit_exceeded`.""" + code: Literal["server_error", "rate_limit_exceeded", "invalid_prompt"] + """One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`.""" message: str """A human-readable description of the error.""" diff --git a/src/openai/types/beta/threads/runs/code_tool_call.py b/src/openai/types/beta/threads/runs/code_tool_call.py index f808005ecb..0de47b379b 100644 --- a/src/openai/types/beta/threads/runs/code_tool_call.py +++ b/src/openai/types/beta/threads/runs/code_tool_call.py @@ -1,8 +1,9 @@ # File generated from our OpenAPI spec by Stainless. from typing import List, Union -from typing_extensions import Literal +from typing_extensions import Literal, Annotated +from ....._utils import PropertyInfo from ....._models import BaseModel __all__ = [ @@ -38,7 +39,9 @@ class CodeInterpreterOutputImage(BaseModel): """Always `image`.""" -CodeInterpreterOutput = Union[CodeInterpreterOutputLogs, CodeInterpreterOutputImage] +CodeInterpreterOutput = Annotated[ + Union[CodeInterpreterOutputLogs, CodeInterpreterOutputImage], PropertyInfo(discriminator="type") +] class CodeInterpreter(BaseModel): diff --git a/src/openai/types/beta/threads/runs/run_step.py b/src/openai/types/beta/threads/runs/run_step.py index 01aab8e9a6..899883ac2d 100644 --- a/src/openai/types/beta/threads/runs/run_step.py +++ b/src/openai/types/beta/threads/runs/run_step.py @@ -1,8 +1,9 @@ # File generated from our OpenAPI spec by Stainless. from typing import Union, Optional -from typing_extensions import Literal +from typing_extensions import Literal, Annotated +from ....._utils import PropertyInfo from ....._models import BaseModel from .tool_calls_step_details import ToolCallsStepDetails from .message_creation_step_details import MessageCreationStepDetails @@ -18,7 +19,7 @@ class LastError(BaseModel): """A human-readable description of the error.""" -StepDetails = Union[MessageCreationStepDetails, ToolCallsStepDetails] +StepDetails = Annotated[Union[MessageCreationStepDetails, ToolCallsStepDetails], PropertyInfo(discriminator="type")] class Usage(BaseModel): diff --git a/src/openai/types/beta/threads/runs/tool_calls_step_details.py b/src/openai/types/beta/threads/runs/tool_calls_step_details.py index 80eb90bf66..b1b5a72bee 100644 --- a/src/openai/types/beta/threads/runs/tool_calls_step_details.py +++ b/src/openai/types/beta/threads/runs/tool_calls_step_details.py @@ -1,8 +1,9 @@ # File generated from our OpenAPI spec by Stainless. from typing import List, Union -from typing_extensions import Literal +from typing_extensions import Literal, Annotated +from ....._utils import PropertyInfo from ....._models import BaseModel from .code_tool_call import CodeToolCall from .function_tool_call import FunctionToolCall @@ -10,7 +11,7 @@ __all__ = ["ToolCallsStepDetails", "ToolCall"] -ToolCall = Union[CodeToolCall, RetrievalToolCall, FunctionToolCall] +ToolCall = Annotated[Union[CodeToolCall, RetrievalToolCall, FunctionToolCall], PropertyInfo(discriminator="type")] class ToolCallsStepDetails(BaseModel): diff --git a/src/openai/types/beta/threads/thread_message.py b/src/openai/types/beta/threads/thread_message.py index 25b3a199f7..6ed5da1401 100644 --- a/src/openai/types/beta/threads/thread_message.py +++ b/src/openai/types/beta/threads/thread_message.py @@ -1,15 +1,16 @@ # File generated from our OpenAPI spec by Stainless. from typing import List, Union, Optional -from typing_extensions import Literal +from typing_extensions import Literal, Annotated +from ...._utils import PropertyInfo from ...._models import BaseModel from .message_content_text import MessageContentText from .message_content_image_file import MessageContentImageFile __all__ = ["ThreadMessage", "Content"] -Content = Union[MessageContentImageFile, MessageContentText] +Content = Annotated[Union[MessageContentImageFile, MessageContentText], PropertyInfo(discriminator="type")] class ThreadMessage(BaseModel): diff --git a/src/openai/types/chat/chat_completion_token_logprob.py b/src/openai/types/chat/chat_completion_token_logprob.py index 728845fb33..076ffb680c 100644 --- a/src/openai/types/chat/chat_completion_token_logprob.py +++ b/src/openai/types/chat/chat_completion_token_logprob.py @@ -20,7 +20,12 @@ class TopLogprob(BaseModel): """ logprob: float - """The log probability of this token.""" + """The log probability of this token, if it is within the top 20 most likely + tokens. + + Otherwise, the value `-9999.0` is used to signify that the token is very + unlikely. + """ class ChatCompletionTokenLogprob(BaseModel): @@ -36,7 +41,12 @@ class ChatCompletionTokenLogprob(BaseModel): """ logprob: float - """The log probability of this token.""" + """The log probability of this token, if it is within the top 20 most likely + tokens. + + Otherwise, the value `-9999.0` is used to signify that the token is very + unlikely. + """ top_logprobs: List[TopLogprob] """List of the most likely tokens and their log probability, at this token diff --git a/src/openai/types/chat/completion_create_params.py b/src/openai/types/chat/completion_create_params.py index e02a81bc51..9afbacb874 100644 --- a/src/openai/types/chat/completion_create_params.py +++ b/src/openai/types/chat/completion_create_params.py @@ -195,9 +195,9 @@ class CompletionCreateParamsBase(TypedDict, total=False): top_logprobs: Optional[int] """ - An integer between 0 and 5 specifying the number of most likely tokens to return - at each token position, each with an associated log probability. `logprobs` must - be set to `true` if this parameter is used. + An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. """ top_p: Optional[float] diff --git a/src/openai/types/image_create_variation_params.py b/src/openai/types/image_create_variation_params.py index 7b015fc176..5714f97fa9 100644 --- a/src/openai/types/image_create_variation_params.py +++ b/src/openai/types/image_create_variation_params.py @@ -32,7 +32,8 @@ class ImageCreateVariationParams(TypedDict, total=False): response_format: Optional[Literal["url", "b64_json"]] """The format in which the generated images are returned. - Must be one of `url` or `b64_json`. + Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the + image has been generated. """ size: Optional[Literal["256x256", "512x512", "1024x1024"]] diff --git a/src/openai/types/image_edit_params.py b/src/openai/types/image_edit_params.py index 043885cc38..751ec4fe7a 100644 --- a/src/openai/types/image_edit_params.py +++ b/src/openai/types/image_edit_params.py @@ -43,7 +43,8 @@ class ImageEditParams(TypedDict, total=False): response_format: Optional[Literal["url", "b64_json"]] """The format in which the generated images are returned. - Must be one of `url` or `b64_json`. + Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the + image has been generated. """ size: Optional[Literal["256x256", "512x512", "1024x1024"]] diff --git a/src/openai/types/image_generate_params.py b/src/openai/types/image_generate_params.py index 7eca29a7ba..3ff1b979db 100644 --- a/src/openai/types/image_generate_params.py +++ b/src/openai/types/image_generate_params.py @@ -35,7 +35,8 @@ class ImageGenerateParams(TypedDict, total=False): response_format: Optional[Literal["url", "b64_json"]] """The format in which the generated images are returned. - Must be one of `url` or `b64_json`. + Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the + image has been generated. """ size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] diff --git a/src/openai/types/moderation.py b/src/openai/types/moderation.py index 09c9a6058b..1c26ec3367 100644 --- a/src/openai/types/moderation.py +++ b/src/openai/types/moderation.py @@ -114,7 +114,4 @@ class Moderation(BaseModel): """A list of the categories along with their scores as predicted by model.""" flagged: bool - """ - Whether the content violates - [OpenAI's usage policies](/policies/usage-policies). - """ + """Whether any of the below categories are flagged.""" diff --git a/tests/sample_file.txt b/tests/sample_file.txt new file mode 100644 index 0000000000..af5626b4a1 --- /dev/null +++ b/tests/sample_file.txt @@ -0,0 +1 @@ +Hello, world! diff --git a/tests/test_client.py b/tests/test_client.py index 625b822352..a6f936da67 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -292,6 +292,16 @@ def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + async def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + async with httpx.AsyncClient() as http_client: + OpenAI( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + def test_default_headers_option(self) -> None: client = OpenAI( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} @@ -994,6 +1004,16 @@ async def test_http_client_timeout_option(self) -> None: timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore assert timeout == DEFAULT_TIMEOUT # our default + def test_invalid_http_client(self) -> None: + with pytest.raises(TypeError, match="Invalid `http_client` arg"): + with httpx.Client() as http_client: + AsyncOpenAI( + base_url=base_url, + api_key=api_key, + _strict_response_validation=True, + http_client=cast(Any, http_client), + ) + def test_default_headers_option(self) -> None: client = AsyncOpenAI( base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"} diff --git a/tests/test_legacy_response.py b/tests/test_legacy_response.py index 995250a58c..45025f81d0 100644 --- a/tests/test_legacy_response.py +++ b/tests/test_legacy_response.py @@ -1,4 +1,6 @@ import json +from typing import cast +from typing_extensions import Annotated import httpx import pytest @@ -63,3 +65,20 @@ def test_response_parse_custom_model(client: OpenAI) -> None: obj = response.parse(to=CustomModel) assert obj.foo == "hello!" assert obj.bar == 2 + + +def test_response_parse_annotated_type(client: OpenAI) -> None: + response = LegacyAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 diff --git a/tests/test_models.py b/tests/test_models.py index 713bd2cb1b..d003d32181 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,14 +1,15 @@ import json from typing import Any, Dict, List, Union, Optional, cast from datetime import datetime, timezone -from typing_extensions import Literal +from typing_extensions import Literal, Annotated import pytest import pydantic from pydantic import Field +from openai._utils import PropertyInfo from openai._compat import PYDANTIC_V2, parse_obj, model_dump, model_json -from openai._models import BaseModel +from openai._models import BaseModel, construct_type class BasicModel(BaseModel): @@ -571,3 +572,194 @@ class OurModel(BaseModel): foo: Optional[str] = None takes_pydantic(OurModel()) + + +def test_annotated_types() -> None: + class Model(BaseModel): + value: str + + m = construct_type( + value={"value": "foo"}, + type_=cast(Any, Annotated[Model, "random metadata"]), + ) + assert isinstance(m, Model) + assert m.value == "foo" + + +def test_discriminated_unions_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, A) + assert m.type == "a" + if PYDANTIC_V2: + assert m.data == 100 # type: ignore[comparison-overlap] + else: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + + +def test_discriminated_unions_unknown_variant() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + m = construct_type( + value={"type": "c", "data": None, "new_thing": "bar"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + + # just chooses the first variant + assert isinstance(m, A) + assert m.type == "c" # type: ignore[comparison-overlap] + assert m.data == None # type: ignore[unreachable] + assert m.new_thing == "bar" + + +def test_discriminated_unions_invalid_data_nested_unions() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + class C(BaseModel): + type: Literal["c"] + + data: bool + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "c", "data": "foo"}, + type_=cast(Any, Annotated[Union[Union[A, B], C], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, C) + assert m.type == "c" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_with_aliases_invalid_data() -> None: + class A(BaseModel): + foo_type: Literal["a"] = Field(alias="type") + + data: str + + class B(BaseModel): + foo_type: Literal["b"] = Field(alias="type") + + data: int + + m = construct_type( + value={"type": "b", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, B) + assert m.foo_type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + m = construct_type( + value={"type": "a", "data": 100}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="foo_type")]), + ) + assert isinstance(m, A) + assert m.foo_type == "a" + if PYDANTIC_V2: + assert m.data == 100 # type: ignore[comparison-overlap] + else: + # pydantic v1 automatically converts inputs to strings + # if the expected type is a str + assert m.data == "100" + + +def test_discriminated_unions_overlapping_discriminators_invalid_data() -> None: + class A(BaseModel): + type: Literal["a"] + + data: bool + + class B(BaseModel): + type: Literal["a"] + + data: int + + m = construct_type( + value={"type": "a", "data": "foo"}, + type_=cast(Any, Annotated[Union[A, B], PropertyInfo(discriminator="type")]), + ) + assert isinstance(m, B) + assert m.type == "a" + assert m.data == "foo" # type: ignore[comparison-overlap] + + +def test_discriminated_unions_invalid_data_uses_cache() -> None: + class A(BaseModel): + type: Literal["a"] + + data: str + + class B(BaseModel): + type: Literal["b"] + + data: int + + UnionType = cast(Any, Union[A, B]) + + assert not hasattr(UnionType, "__discriminator__") + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + discriminator = UnionType.__discriminator__ + assert discriminator is not None + + m = construct_type( + value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) + ) + assert isinstance(m, B) + assert m.type == "b" + assert m.data == "foo" # type: ignore[comparison-overlap] + + # if the discriminator details object stays the same between invocations then + # we hit the cache + assert UnionType.__discriminator__ is discriminator diff --git a/tests/test_response.py b/tests/test_response.py index 7c99327b46..af153b67c4 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,5 +1,6 @@ import json -from typing import List +from typing import List, cast +from typing_extensions import Annotated import httpx import pytest @@ -157,3 +158,37 @@ async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> N obj = await response.parse(to=CustomModel) assert obj.foo == "hello!" assert obj.bar == 2 + + +def test_response_parse_annotated_type(client: OpenAI) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +async def test_async_response_parse_annotated_type(async_client: AsyncOpenAI) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse( + to=cast("type[CustomModel]", Annotated[CustomModel, "random metadata"]), + ) + assert obj.foo == "hello!" + assert obj.bar == 2 diff --git a/tests/test_transform.py b/tests/test_transform.py index 6ed67d49a7..0d17e8a972 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,22 +1,50 @@ from __future__ import annotations -from typing import Any, List, Union, Iterable, Optional, cast +import io +import pathlib +from typing import Any, List, Union, TypeVar, Iterable, Optional, cast from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict import pytest -from openai._utils import PropertyInfo, transform, parse_datetime +from openai._types import Base64FileInput +from openai._utils import ( + PropertyInfo, + transform as _transform, + parse_datetime, + async_transform as _async_transform, +) from openai._compat import PYDANTIC_V2 from openai._models import BaseModel +_T = TypeVar("_T") + +SAMPLE_FILE_PATH = pathlib.Path(__file__).parent.joinpath("sample_file.txt") + + +async def transform( + data: _T, + expected_type: object, + use_async: bool, +) -> _T: + if use_async: + return await _async_transform(data, expected_type=expected_type) + + return _transform(data, expected_type=expected_type) + + +parametrize = pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) + class Foo1(TypedDict): foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] -def test_top_level_alias() -> None: - assert transform({"foo_bar": "hello"}, expected_type=Foo1) == {"fooBar": "hello"} +@parametrize +@pytest.mark.asyncio +async def test_top_level_alias(use_async: bool) -> None: + assert await transform({"foo_bar": "hello"}, expected_type=Foo1, use_async=use_async) == {"fooBar": "hello"} class Foo2(TypedDict): @@ -32,9 +60,11 @@ class Baz2(TypedDict): my_baz: Annotated[str, PropertyInfo(alias="myBaz")] -def test_recursive_typeddict() -> None: - assert transform({"bar": {"this_thing": 1}}, Foo2) == {"bar": {"this__thing": 1}} - assert transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2) == {"bar": {"Baz": {"myBaz": "foo"}}} +@parametrize +@pytest.mark.asyncio +async def test_recursive_typeddict(use_async: bool) -> None: + assert await transform({"bar": {"this_thing": 1}}, Foo2, use_async) == {"bar": {"this__thing": 1}} + assert await transform({"bar": {"baz": {"my_baz": "foo"}}}, Foo2, use_async) == {"bar": {"Baz": {"myBaz": "foo"}}} class Foo3(TypedDict): @@ -45,8 +75,10 @@ class Bar3(TypedDict): my_field: Annotated[str, PropertyInfo(alias="myField")] -def test_list_of_typeddict() -> None: - result = transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, expected_type=Foo3) +@parametrize +@pytest.mark.asyncio +async def test_list_of_typeddict(use_async: bool) -> None: + result = await transform({"things": [{"my_field": "foo"}, {"my_field": "foo2"}]}, Foo3, use_async) assert result == {"things": [{"myField": "foo"}, {"myField": "foo2"}]} @@ -62,10 +94,14 @@ class Baz4(TypedDict): foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] -def test_union_of_typeddict() -> None: - assert transform({"foo": {"foo_bar": "bar"}}, Foo4) == {"foo": {"fooBar": "bar"}} - assert transform({"foo": {"foo_baz": "baz"}}, Foo4) == {"foo": {"fooBaz": "baz"}} - assert transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4) == {"foo": {"fooBaz": "baz", "fooBar": "bar"}} +@parametrize +@pytest.mark.asyncio +async def test_union_of_typeddict(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo4, use_async) == {"foo": {"fooBar": "bar"}} + assert await transform({"foo": {"foo_baz": "baz"}}, Foo4, use_async) == {"foo": {"fooBaz": "baz"}} + assert await transform({"foo": {"foo_baz": "baz", "foo_bar": "bar"}}, Foo4, use_async) == { + "foo": {"fooBaz": "baz", "fooBar": "bar"} + } class Foo5(TypedDict): @@ -80,9 +116,11 @@ class Baz5(TypedDict): foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] -def test_union_of_list() -> None: - assert transform({"foo": {"foo_bar": "bar"}}, Foo5) == {"FOO": {"fooBar": "bar"}} - assert transform( +@parametrize +@pytest.mark.asyncio +async def test_union_of_list(use_async: bool) -> None: + assert await transform({"foo": {"foo_bar": "bar"}}, Foo5, use_async) == {"FOO": {"fooBar": "bar"}} + assert await transform( { "foo": [ {"foo_baz": "baz"}, @@ -90,6 +128,7 @@ def test_union_of_list() -> None: ] }, Foo5, + use_async, ) == {"FOO": [{"fooBaz": "baz"}, {"fooBaz": "baz"}]} @@ -97,8 +136,10 @@ class Foo6(TypedDict): bar: Annotated[str, PropertyInfo(alias="Bar")] -def test_includes_unknown_keys() -> None: - assert transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6) == { +@parametrize +@pytest.mark.asyncio +async def test_includes_unknown_keys(use_async: bool) -> None: + assert await transform({"bar": "bar", "baz_": {"FOO": 1}}, Foo6, use_async) == { "Bar": "bar", "baz_": {"FOO": 1}, } @@ -113,9 +154,11 @@ class Bar7(TypedDict): foo: str -def test_ignores_invalid_input() -> None: - assert transform({"bar": ""}, Foo7) == {"bAr": ""} - assert transform({"foo": ""}, Foo7) == {"foo": ""} +@parametrize +@pytest.mark.asyncio +async def test_ignores_invalid_input(use_async: bool) -> None: + assert await transform({"bar": ""}, Foo7, use_async) == {"bAr": ""} + assert await transform({"foo": ""}, Foo7, use_async) == {"foo": ""} class DatetimeDict(TypedDict, total=False): @@ -134,52 +177,66 @@ class DateDict(TypedDict, total=False): foo: Annotated[date, PropertyInfo(format="iso8601")] -def test_iso8601_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] dt = dt.replace(tzinfo=None) - assert transform({"foo": dt}, DatetimeDict) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] + assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap] - assert transform({"foo": None}, DateDict) == {"foo": None} # type: ignore[comparison-overlap] - assert transform({"foo": date.fromisoformat("2023-02-23")}, DateDict) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] + assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap] + assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap] -def test_optional_iso8601_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_optional_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"bar": dt}, DatetimeDict) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform({"bar": dt}, DatetimeDict, use_async) == {"bar": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] - assert transform({"bar": None}, DatetimeDict) == {"bar": None} + assert await transform({"bar": None}, DatetimeDict, use_async) == {"bar": None} -def test_required_iso8601_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_required_iso8601_format(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"required": dt}, DatetimeDict) == {"required": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap] + assert await transform({"required": dt}, DatetimeDict, use_async) == { + "required": "2023-02-23T14:16:36.337692+00:00" + } # type: ignore[comparison-overlap] - assert transform({"required": None}, DatetimeDict) == {"required": None} + assert await transform({"required": None}, DatetimeDict, use_async) == {"required": None} -def test_union_datetime() -> None: +@parametrize +@pytest.mark.asyncio +async def test_union_datetime(use_async: bool) -> None: dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") - assert transform({"union": dt}, DatetimeDict) == { # type: ignore[comparison-overlap] + assert await transform({"union": dt}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] "union": "2023-02-23T14:16:36.337692+00:00" } - assert transform({"union": "foo"}, DatetimeDict) == {"union": "foo"} + assert await transform({"union": "foo"}, DatetimeDict, use_async) == {"union": "foo"} -def test_nested_list_iso6801_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_nested_list_iso6801_format(use_async: bool) -> None: dt1 = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00") dt2 = parse_datetime("2022-01-15T06:34:23Z") - assert transform({"list_": [dt1, dt2]}, DatetimeDict) == { # type: ignore[comparison-overlap] + assert await transform({"list_": [dt1, dt2]}, DatetimeDict, use_async) == { # type: ignore[comparison-overlap] "list_": ["2023-02-23T14:16:36.337692+00:00", "2022-01-15T06:34:23+00:00"] } -def test_datetime_custom_format() -> None: +@parametrize +@pytest.mark.asyncio +async def test_datetime_custom_format(use_async: bool) -> None: dt = parse_datetime("2022-01-15T06:34:23Z") - result = transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")]) + result = await transform(dt, Annotated[datetime, PropertyInfo(format="custom", format_template="%H")], use_async) assert result == "06" # type: ignore[comparison-overlap] @@ -187,47 +244,59 @@ class DateDictWithRequiredAlias(TypedDict, total=False): required_prop: Required[Annotated[date, PropertyInfo(format="iso8601", alias="prop")]] -def test_datetime_with_alias() -> None: - assert transform({"required_prop": None}, DateDictWithRequiredAlias) == {"prop": None} # type: ignore[comparison-overlap] - assert transform({"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias) == { - "prop": "2023-02-23" - } # type: ignore[comparison-overlap] +@parametrize +@pytest.mark.asyncio +async def test_datetime_with_alias(use_async: bool) -> None: + assert await transform({"required_prop": None}, DateDictWithRequiredAlias, use_async) == {"prop": None} # type: ignore[comparison-overlap] + assert await transform( + {"required_prop": date.fromisoformat("2023-02-23")}, DateDictWithRequiredAlias, use_async + ) == {"prop": "2023-02-23"} # type: ignore[comparison-overlap] class MyModel(BaseModel): foo: str -def test_pydantic_model_to_dictionary() -> None: - assert transform(MyModel(foo="hi!"), Any) == {"foo": "hi!"} - assert transform(MyModel.construct(foo="hi!"), Any) == {"foo": "hi!"} +@parametrize +@pytest.mark.asyncio +async def test_pydantic_model_to_dictionary(use_async: bool) -> None: + assert await transform(MyModel(foo="hi!"), Any, use_async) == {"foo": "hi!"} + assert await transform(MyModel.construct(foo="hi!"), Any, use_async) == {"foo": "hi!"} -def test_pydantic_empty_model() -> None: - assert transform(MyModel.construct(), Any) == {} +@parametrize +@pytest.mark.asyncio +async def test_pydantic_empty_model(use_async: bool) -> None: + assert await transform(MyModel.construct(), Any, use_async) == {} -def test_pydantic_unknown_field() -> None: - assert transform(MyModel.construct(my_untyped_field=True), Any) == {"my_untyped_field": True} +@parametrize +@pytest.mark.asyncio +async def test_pydantic_unknown_field(use_async: bool) -> None: + assert await transform(MyModel.construct(my_untyped_field=True), Any, use_async) == {"my_untyped_field": True} -def test_pydantic_mismatched_types() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_types(use_async: bool) -> None: model = MyModel.construct(foo=True) if PYDANTIC_V2: with pytest.warns(UserWarning): - params = transform(model, Any) + params = await transform(model, Any, use_async) else: - params = transform(model, Any) + params = await transform(model, Any, use_async) assert params == {"foo": True} -def test_pydantic_mismatched_object_type() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_mismatched_object_type(use_async: bool) -> None: model = MyModel.construct(foo=MyModel.construct(hello="world")) if PYDANTIC_V2: with pytest.warns(UserWarning): - params = transform(model, Any) + params = await transform(model, Any, use_async) else: - params = transform(model, Any) + params = await transform(model, Any, use_async) assert params == {"foo": {"hello": "world"}} @@ -235,10 +304,12 @@ class ModelNestedObjects(BaseModel): nested: MyModel -def test_pydantic_nested_objects() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_nested_objects(use_async: bool) -> None: model = ModelNestedObjects.construct(nested={"foo": "stainless"}) assert isinstance(model.nested, MyModel) - assert transform(model, Any) == {"nested": {"foo": "stainless"}} + assert await transform(model, Any, use_async) == {"nested": {"foo": "stainless"}} class ModelWithDefaultField(BaseModel): @@ -247,24 +318,26 @@ class ModelWithDefaultField(BaseModel): with_str_default: str = "foo" -def test_pydantic_default_field() -> None: +@parametrize +@pytest.mark.asyncio +async def test_pydantic_default_field(use_async: bool) -> None: # should be excluded when defaults are used model = ModelWithDefaultField.construct() assert model.with_none_default is None assert model.with_str_default == "foo" - assert transform(model, Any) == {} + assert await transform(model, Any, use_async) == {} # should be included when the default value is explicitly given model = ModelWithDefaultField.construct(with_none_default=None, with_str_default="foo") assert model.with_none_default is None assert model.with_str_default == "foo" - assert transform(model, Any) == {"with_none_default": None, "with_str_default": "foo"} + assert await transform(model, Any, use_async) == {"with_none_default": None, "with_str_default": "foo"} # should be included when a non-default value is explicitly given model = ModelWithDefaultField.construct(with_none_default="bar", with_str_default="baz") assert model.with_none_default == "bar" assert model.with_str_default == "baz" - assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"} + assert await transform(model, Any, use_async) == {"with_none_default": "bar", "with_str_default": "baz"} class TypedDictIterableUnion(TypedDict): @@ -279,21 +352,57 @@ class Baz8(TypedDict): foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] -def test_iterable_of_dictionaries() -> None: - assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]} - assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]} +@parametrize +@pytest.mark.asyncio +async def test_iterable_of_dictionaries(use_async: bool) -> None: + assert await transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "bar"}] + } + assert cast(Any, await transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion, use_async)) == { + "FOO": [{"fooBaz": "bar"}] + } def my_iter() -> Iterable[Baz8]: yield {"foo_baz": "hello"} yield {"foo_baz": "world"} - assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]} + assert await transform({"foo": my_iter()}, TypedDictIterableUnion, use_async) == { + "FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}] + } class TypedDictIterableUnionStr(TypedDict): foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] -def test_iterable_union_str() -> None: - assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"} - assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}] +@parametrize +@pytest.mark.asyncio +async def test_iterable_union_str(use_async: bool) -> None: + assert await transform({"foo": "bar"}, TypedDictIterableUnionStr, use_async) == {"FOO": "bar"} + assert cast(Any, await transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]], use_async)) == [ + {"fooBaz": "bar"} + ] + + +class TypedDictBase64Input(TypedDict): + foo: Annotated[Union[str, Base64FileInput], PropertyInfo(format="base64")] + + +@parametrize +@pytest.mark.asyncio +async def test_base64_file_input(use_async: bool) -> None: + # strings are left as-is + assert await transform({"foo": "bar"}, TypedDictBase64Input, use_async) == {"foo": "bar"} + + # pathlib.Path is automatically converted to base64 + assert await transform({"foo": SAMPLE_FILE_PATH}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQo=" + } # type: ignore[comparison-overlap] + + # io instances are automatically converted to base64 + assert await transform({"foo": io.StringIO("Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] + assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { + "foo": "SGVsbG8sIHdvcmxkIQ==" + } # type: ignore[comparison-overlap] diff --git a/tests/utils.py b/tests/utils.py index 216b333550..43c3cb5cfe 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,6 +14,8 @@ is_list, is_list_type, is_union_type, + extract_type_arg, + is_annotated_type, ) from openai._compat import PYDANTIC_V2, field_outer_type, get_model_fields from openai._models import BaseModel @@ -49,6 +51,10 @@ def assert_matches_type( path: list[str], allow_none: bool = False, ) -> None: + # unwrap `Annotated[T, ...]` -> `T` + if is_annotated_type(type_): + type_ = extract_type_arg(type_, 0) + if allow_none and value is None: return